// c10/cuda/CUDAStream.h
#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <tuple>
#include <utility>

#include <cuda_runtime_api.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>
/*
 * Stream pool note.
 *
 * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
 * are backed by cuStreams, but they use several pools to minimize the costs
 * associated with creating, retaining, and destroying cuStreams.
 *
 * There are three pools per device, and a device's pools are lazily created.
 *
 * The first pool contains only the default stream. When the default stream
 * is requested it's returned.
 *
 * The second pool is the "low priority" or "default priority" streams. In
 * HIP builds there is no distinction between streams in this pool and streams
 * in the third pool (below). There are 32 of these streams per device, and
 * when a stream is requested one of these streams is returned round-robin.
 * That is, the first stream requested is at index 0, the second at index 1...
 * to index 31, then index 0 again.
 *
 * This means that if 33 low priority streams are requested, the first and
 * last streams requested are actually the same stream (under the covers)
 * and kernels enqueued on them cannot run concurrently.
 *
 * The third pool is the "high priority" streams. The third pool acts like
 * the second pool except the streams are created with a higher priority.
 *
 * These pools suggest that stream users should prefer many short-lived streams,
 * as the cost of acquiring and releasing streams is effectively zero. If
 * many longer-lived streams are required in performance critical scenarios
 * then the functionality here may need to be extended to allow, for example,
 * "reserving" a subset of the pool so that other streams do not accidentally
 * overlap the performance critical streams.
 *
 * Note: although the notion of "current stream for device" is thread local
 * (every OS thread has a separate current stream, as one might expect),
 * the stream pool is global across all threads; stream 0 is always stream 0
 * no matter which thread you use it on. Multiple threads can synchronize
 * on the same stream. Although the CUDA documentation is not very clear
 * on the matter, streams are thread safe; e.g., it is safe to enqueue
 * a kernel on the same stream from two different threads.
 */
  50. namespace c10 {
  51. namespace cuda {
  52. // Value object representing a CUDA stream. This is just a wrapper
  53. // around c10::Stream, but it comes with a little extra CUDA-specific
  54. // functionality (conversion to cudaStream_t), and a guarantee that
  55. // the wrapped c10::Stream really is a CUDA stream.
  56. class C10_CUDA_API CUDAStream {
  57. public:
  58. enum Unchecked { UNCHECKED };
  59. /// Construct a CUDAStream from a Stream. This construction is checked,
  60. /// and will raise an error if the Stream is not, in fact, a CUDA stream.
  61. explicit CUDAStream(Stream stream) : stream_(stream) {
  62. TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
  63. }
  64. /// Construct a CUDAStream from a Stream with no error checking.
  65. /// This constructor uses the "named" constructor idiom, and can
  66. /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
  67. explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {}
  68. bool operator==(const CUDAStream& other) const noexcept {
  69. return unwrap() == other.unwrap();
  70. }
  71. bool operator!=(const CUDAStream& other) const noexcept {
  72. return unwrap() != other.unwrap();
  73. }
  74. /// Implicit conversion to cudaStream_t.
  75. operator cudaStream_t() const {
  76. return stream();
  77. }
  78. /// Implicit conversion to Stream (a.k.a., forget that the stream is a
  79. /// CUDA stream).
  80. operator Stream() const {
  81. return unwrap();
  82. }
  83. /// Used to avoid baking in device type explicitly to Python-side API.
  84. DeviceType device_type() const {
  85. return DeviceType::CUDA;
  86. }
  87. /// Get the CUDA device index that this stream is associated with.
  88. DeviceIndex device_index() const {
  89. return stream_.device_index();
  90. }
  91. /// Get the full Device that this stream is associated with. The Device
  92. /// is guaranteed to be a CUDA device.
  93. Device device() const {
  94. return Device(DeviceType::CUDA, device_index());
  95. }
  96. /// Return the stream ID corresponding to this particular stream.
  97. StreamId id() const {
  98. return stream_.id();
  99. }
  100. bool query() const {
  101. DeviceGuard guard{stream_.device()};
  102. cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream()));
  103. if (err == cudaSuccess) {
  104. return true;
  105. } else if (err != cudaErrorNotReady) {
  106. C10_CUDA_CHECK(err);
  107. } else {
  108. // ignore and clear the error if not ready
  109. (void)cudaGetLastError();
  110. }
  111. return false;
  112. }
  113. void synchronize() const {
  114. DeviceGuard guard{stream_.device()};
  115. c10::cuda::stream_synchronize(stream());
  116. }
  117. int priority() const {
  118. DeviceGuard guard{stream_.device()};
  119. int priority = 0;
  120. C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
  121. return priority;
  122. }
  123. /// Explicit conversion to cudaStream_t.
  124. cudaStream_t stream() const;
  125. /// Explicit conversion to Stream.
  126. Stream unwrap() const {
  127. return stream_;
  128. }
  129. /// Reversibly pack a CUDAStream into a struct representation.
  130. /// Previously the stream's data was packed into a single int64_t,
  131. /// as it was assumed the fields would not require more than
  132. /// 64 bits of storage in total.
  133. /// See https://github.com/pytorch/pytorch/issues/75854
  134. /// for more information regarding newer platforms that may violate
  135. /// this assumption.
  136. ///
  137. /// The CUDAStream can be unpacked using unpack().
  138. struct c10::StreamData3 pack3() const {
  139. return stream_.pack3();
  140. }
  141. // Unpack a CUDAStream from the 3 fields generated by pack().
  142. static CUDAStream unpack3(
  143. StreamId stream_id,
  144. DeviceIndex device_index,
  145. DeviceType device_type) {
  146. return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
  147. }
  148. static std::tuple<int, int> priority_range() {
  149. // Note: this returns the range of priority **supported by PyTorch**, not
  150. // the range of priority **supported by CUDA**. The former is a subset of
  151. // the latter. Currently PyTorch only supports 0 and -1, which are "low" and
  152. // "high" priority.
  153. int least_priority, greatest_priority;
  154. C10_CUDA_CHECK(
  155. cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
  156. TORCH_INTERNAL_ASSERT(
  157. least_priority >= 0, "Unexpected CUDA stream priority range");
  158. TORCH_INTERNAL_ASSERT(
  159. greatest_priority <= -1, "Unexpected CUDA stream priority range");
  160. return std::make_tuple(0, -1);
  161. }
  162. // Deleted for now; use CUDAEvent::block instead
  163. // void synchronize_with(const CUDAEvent& event) const;
  164. private:
  165. Stream stream_;
  166. };
// NOTE(review): the class above is exported with C10_CUDA_API while these
// free functions use C10_API — confirm the intended export macro for the
// c10_cuda library before changing either.

/**
 * Get a new stream from the CUDA stream pool. You can think of this
 * as "creating" a new stream, but no such creation actually happens;
 * instead, streams are preallocated from the pool and returned in a
 * round-robin fashion (see the "Stream pool note" at the top of this file).
 *
 * You can request a stream from the high priority pool by setting
 * isHighPriority to true, or a stream for a specific device by setting device
 * (defaulting to the current CUDA stream.)
 */
C10_API CUDAStream
getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);

/**
 * Get a CUDAStream from a externally allocated one.
 *
 * This is mainly for interoperability with different libraries where we
 * want to operate on a non-torch allocated stream for data exchange or similar
 * purposes
 */
C10_API CUDAStream
getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);

/**
 * Get the default CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The default stream is
 * where most computation occurs when you aren't explicitly using
 * streams.
 */
C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);

/**
 * Get the current CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The current CUDA stream
 * will usually be the default CUDA stream for the device, but it may
 * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
 * or 'CUDAStreamGuard'.
 */
C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);

/**
 * Set the current stream on the device of the passed in stream to be
 * the passed in stream. Yes, you read that right: this function
 * has *nothing* to do with the current device: it toggles the current
 * stream of the device of the passed stream.
 *
 * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead
 * (which will switch both your current device and current stream in the way you
 * expect, and reset it back to its original state afterwards).
 */
C10_API void setCurrentCUDAStream(CUDAStream stream);

// Stream insertion operator; formats the stream via its c10::Stream data.
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
  215. } // namespace cuda
  216. } // namespace c10
  217. namespace std {
  218. template <>
  219. struct hash<c10::cuda::CUDAStream> {
  220. size_t operator()(c10::cuda::CUDAStream s) const noexcept {
  221. return std::hash<c10::Stream>{}(s.unwrap());
  222. }
  223. };
  224. } // namespace std