HIPStreamMasqueradingAsCUDA.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #pragma once
  2. #include <c10/hip/HIPStream.h>
  3. // Use of c10::hip namespace here makes hipification easier, because
  4. // I don't have to also fix namespaces. Sorry!
  5. namespace c10 { namespace hip {
  6. // See Note [Masquerading as CUDA] for motivation
  7. class HIPStreamMasqueradingAsCUDA {
  8. public:
  9. enum Unchecked { UNCHECKED };
  10. explicit HIPStreamMasqueradingAsCUDA(Stream stream)
  11. : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
  12. // We did the coercion unchecked; check that it was right.
  13. TORCH_CHECK(stream.device().is_cuda() /* !!! */);
  14. }
  15. explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
  16. // Unsafely coerce the "CUDA" stream into a HIP stream
  17. : stream_(
  18. HIPStream(
  19. Stream(
  20. Stream::UNSAFE,
  21. Device(DeviceType::HIP, stream.device_index()),
  22. stream.id())
  23. )
  24. ) {}
  25. // New constructor, just for this. Does NOT coerce.
  26. explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
  27. bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
  28. return stream_ == other.stream_;
  29. }
  30. bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
  31. return stream_ != other.stream_;
  32. }
  33. operator hipStream_t() const { return stream_.stream(); }
  34. operator Stream() const {
  35. // Unsafely coerce HIP stream into a "CUDA" stream
  36. return Stream(Stream::UNSAFE, device(), id());
  37. }
  38. DeviceIndex device_index() const { return stream_.device_index(); }
  39. // Unsafely coerce HIP device into CUDA device
  40. DeviceType device_type() const { return DeviceType::CUDA; }
  41. Device device() const {
  42. // Unsafely coerce HIP device into CUDA device
  43. return Device(DeviceType::CUDA, stream_.device_index());
  44. }
  45. StreamId id() const { return stream_.id(); }
  46. bool query() const { return stream_.query(); }
  47. void synchronize() const { stream_.synchronize(); }
  48. int priority() const { return stream_.priority(); }
  49. hipStream_t stream() const { return stream_.stream(); }
  50. Stream unwrap() const {
  51. // Unsafely coerce HIP stream into "CUDA" stream
  52. return Stream(Stream::UNSAFE, device(), id());
  53. }
  54. c10::StreamData3 pack3() const noexcept {
  55. // Unsafely coerce HIP stream into "CUDA" stream before packing
  56. return unwrap().pack3();
  57. }
  58. static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
  59. DeviceIndex device_index,
  60. DeviceType device_type) {
  61. // NB: constructor manages CUDA->HIP translation for us
  62. return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
  63. stream_id, device_index, device_type));
  64. }
  65. static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
  66. // New method, gets the underlying HIPStream
  67. HIPStream hip_stream() const { return stream_; }
  68. private:
  69. HIPStream stream_;
  70. };
  71. HIPStreamMasqueradingAsCUDA
  72. inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
  73. return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
  74. }
  75. HIPStreamMasqueradingAsCUDA
  76. inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
  77. return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
  78. }
  79. inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
  80. return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
  81. }
  82. inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
  83. return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
  84. }
  85. inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
  86. setCurrentHIPStream(stream.hip_stream());
  87. }
  88. inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
  89. stream << s.hip_stream() << " (masquerading as CUDA)";
  90. return stream;
  91. }
  92. }} // namespace c10::hip
  93. namespace std {
  94. template <>
  95. struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
  96. size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
  97. return std::hash<c10::Stream>{}(s.unwrap());
  98. }
  99. };
  100. } // namespace std