// HIPGuardImplMasqueradingAsCUDA.h
#pragma once

#include <ATen/hip/HIPConfig.h>

// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>

#include <c10/hip/impl/HIPGuardImpl.h>

#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces.  Sorry!
namespace c10 { namespace hip {

// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP.  HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto a ROCm device.  We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP).  To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes.  These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds.  For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.
  49. struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
  50. static constexpr DeviceType static_type = DeviceType::CUDA;
  51. HIPGuardImplMasqueradingAsCUDA() {}
  52. HIPGuardImplMasqueradingAsCUDA(DeviceType t) {
  53. TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  54. }
  55. DeviceType type() const override {
  56. return DeviceType::CUDA;
  57. }
  58. Device exchangeDevice(Device d) const override {
  59. TORCH_INTERNAL_ASSERT(d.is_cuda());
  60. Device old_device = getDevice();
  61. if (old_device.index() != d.index()) {
  62. C10_HIP_CHECK(hipSetDevice(d.index()));
  63. }
  64. return old_device;
  65. }
  66. Device getDevice() const override {
  67. int device;
  68. C10_HIP_CHECK(hipGetDevice(&device));
  69. return Device(DeviceType::CUDA, device);
  70. }
  71. void setDevice(Device d) const override {
  72. TORCH_INTERNAL_ASSERT(d.is_cuda());
  73. C10_HIP_CHECK(hipSetDevice(d.index()));
  74. }
  75. void uncheckedSetDevice(Device d) const noexcept override {
  76. C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
  77. }
  78. Stream getStream(Device d) const noexcept override {
  79. return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
  80. }
  81. Stream getDefaultStream(Device d) const override {
  82. return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
  83. }
  84. Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
  85. return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
  86. }
  87. Stream exchangeStream(Stream s) const noexcept override {
  88. HIPStreamMasqueradingAsCUDA cs(s);
  89. auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
  90. setCurrentHIPStreamMasqueradingAsCUDA(cs);
  91. return old_stream.unwrap();
  92. }
  93. DeviceIndex deviceCount() const noexcept override {
  94. int deviceCnt;
  95. hipError_t _err;
  96. _err = hipGetDeviceCount(&deviceCnt);
  97. #if defined(USE_ROCM) && (ROCM_VERSION < 50201)
  98. if(_err == hipErrorInvalidDevice)
  99. return 0;
  100. #endif
  101. if(_err != hipErrorNoDevice && _err != hipSuccess)
  102. C10_HIP_CHECK(_err);
  103. return deviceCnt;
  104. }
  105. // Event-related functions
  106. // Note: hipEventCreateWithFlags should be called on the same device as
  107. // the recording stream's device.
  108. void createEvent(
  109. hipEvent_t* hip_event,
  110. const EventFlag flag) const {
  111. // Maps PyTorch's Event::Flag to HIP flag
  112. auto hip_flag = hipEventDefault;
  113. switch (flag) {
  114. case EventFlag::PYTORCH_DEFAULT:
  115. case EventFlag::HIP_EVENT_DISABLE_TIMING:
  116. hip_flag = hipEventDisableTiming;
  117. break;
  118. case EventFlag::BACKEND_DEFAULT:
  119. case EventFlag::HIP_EVENT_DEFAULT:
  120. hip_flag = hipEventDefault;
  121. break;
  122. default:
  123. TORCH_CHECK(false, "HIP event received unknown flag");
  124. }
  125. C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
  126. }
  127. void destroyEvent(
  128. void* event,
  129. const DeviceIndex device_index) const noexcept override {
  130. if (!event) return;
  131. auto hip_event = static_cast<hipEvent_t>(event);
  132. int orig_device;
  133. C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
  134. C10_HIP_CHECK_WARN(hipSetDevice(device_index));
  135. C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
  136. C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
  137. }
  138. void record(void** event,
  139. const Stream& stream,
  140. const DeviceIndex device_index,
  141. const EventFlag flag) const override {
  142. TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
  143. "Event device index ",
  144. device_index,
  145. " does not match recording stream's device index ",
  146. stream.device_index(),
  147. ".");
  148. hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
  149. HIPStreamMasqueradingAsCUDA hip_stream{stream};
  150. // Moves to stream's device to record
  151. const auto orig_device = getDevice();
  152. setDevice(stream.device());
  153. // Creates the event (lazily)
  154. if (!hip_event) createEvent(&hip_event, flag);
  155. C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
  156. // Makes the void* point to the (possibly just allocated) HIP event
  157. *event = hip_event;
  158. // Resets device
  159. setDevice(orig_device);
  160. }
  161. void block(
  162. void* event,
  163. const Stream& stream) const override {
  164. if (!event) return;
  165. hipEvent_t hip_event = static_cast<hipEvent_t>(event);
  166. HIPStreamMasqueradingAsCUDA hip_stream{stream};
  167. const auto orig_device = getDevice();
  168. setDevice(stream.device());
  169. C10_HIP_CHECK(hipStreamWaitEvent(
  170. hip_stream,
  171. hip_event,
  172. /*flags (must be zero)=*/ 0));
  173. setDevice(orig_device);
  174. }
  175. bool queryEvent(void* event) const override {
  176. if (!event) return true;
  177. hipEvent_t hip_event = static_cast<hipEvent_t>(event);
  178. const hipError_t err = hipEventQuery(hip_event);
  179. if (err != hipErrorNotReady) C10_HIP_CHECK(err);
  180. else {
  181. // ignore and clear the error if not ready
  182. hipGetLastError();
  183. }
  184. return (err == hipSuccess);
  185. }
  186. // Stream-related functions
  187. bool queryStream(const Stream& stream) const override {
  188. HIPStreamMasqueradingAsCUDA hip_stream{stream};
  189. return hip_stream.query();
  190. }
  191. void synchronizeStream(const Stream& stream) const override {
  192. HIPStreamMasqueradingAsCUDA hip_stream{stream};
  193. hip_stream.synchronize();
  194. }
  195. void recordDataPtrOnStream(
  196. const c10::DataPtr& data_ptr,
  197. const Stream& stream) const override {
  198. HIPStreamMasqueradingAsCUDA hip_stream{stream};
  199. HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
  200. }
  201. };
// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.

/// This code is all a direct copy from c10/cuda/CUDAGuard.h, but with
/// the correct InlineDeviceGuard burned in.  Sorry about the
/// copy-pasting.
  207. struct HIPGuardMasqueradingAsCUDA {
  208. explicit HIPGuardMasqueradingAsCUDA() = delete;
  209. explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
  210. explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
  211. HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
  212. HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
  213. HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
  214. HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
  215. void set_device(Device device) { guard_.set_device(device); }
  216. void reset_device(Device device) { guard_.reset_device(device); }
  217. void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  218. Device original_device() const { return guard_.original_device(); }
  219. Device current_device() const { return guard_.current_device(); }
  220. private:
  221. c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
  222. };
  223. struct OptionalHIPGuardMasqueradingAsCUDA {
  224. explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
  225. explicit OptionalHIPGuardMasqueradingAsCUDA(optional<Device> device_opt) : guard_(device_opt) {}
  226. explicit OptionalHIPGuardMasqueradingAsCUDA(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
  227. OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  228. OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
  229. OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
  230. OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
  231. void set_device(Device device) { guard_.set_device(device); }
  232. void reset_device(Device device) { guard_.reset_device(device); }
  233. void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  234. optional<Device> original_device() const { return guard_.original_device(); }
  235. optional<Device> current_device() const { return guard_.current_device(); }
  236. void reset() { guard_.reset(); }
  237. private:
  238. c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
  239. };
  240. struct HIPStreamGuardMasqueradingAsCUDA {
  241. explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
  242. explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
  243. HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  244. HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
  245. HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  246. HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  247. void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  248. HIPStreamMasqueradingAsCUDA original_stream() const {
  249. return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
  250. }
  251. HIPStreamMasqueradingAsCUDA current_stream() const {
  252. return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
  253. }
  254. Device current_device() const { return guard_.current_device(); }
  255. Device original_device() const { return guard_.original_device(); }
  256. private:
  257. c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
  258. };
  259. struct OptionalHIPStreamGuardMasqueradingAsCUDA {
  260. explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
  261. explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
  262. explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional<Stream> stream_opt) : guard_(stream_opt) {}
  263. OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  264. OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
  265. OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  266. OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
  267. void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  268. optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
  269. auto r = guard_.original_stream();
  270. if (r.has_value()) {
  271. return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
  272. } else {
  273. return nullopt;
  274. }
  275. }
  276. optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
  277. auto r = guard_.current_stream();
  278. if (r.has_value()) {
  279. return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value()));
  280. } else {
  281. return nullopt;
  282. }
  283. }
  284. void reset() { guard_.reset(); }
  285. private:
  286. c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
  287. };
  288. struct HIPMultiStreamGuardMasqueradingAsCUDA {
  289. explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
  290. : guard_(unwrapStreams(streams)) {}
  291. HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  292. HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
  293. HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
  294. HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
  295. private:
  296. c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
  297. static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
  298. std::vector<Stream> streams;
  299. streams.reserve(hipStreams.size());
  300. for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
  301. streams.push_back(hipStream);
  302. }
  303. return streams;
  304. }
  305. };
}} // namespace c10::hip