#pragma once

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <memory>
#include <mutex>

namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(
          device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  Device getDeviceFromPtr(void* data, DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  static bool isPinnedPtr(void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::Lazy);
  }
  static bool hasORT() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
  }
  // Defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);
  // NB: This method reports *purely* whether or not a user requested that
  // CuDNN be enabled; it doesn't say anything about whether CuDNN is
  // actually usable. Use cudnn_is_acceptable to test that instead.
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);
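  // For illustration, a minimal sketch of how a caller might configure the
  // cuDNN-related flags above through the global context (the function name
  // and the particular values are illustrative only):
  //
  //    void configure_cudnn_example() {
  //      at::globalContext().setUserEnabledCuDNN(true);
  //      at::globalContext().setBenchmarkCuDNN(true);     // autotune conv algorithms
  //      at::globalContext().setBenchmarkLimitCuDNN(10);  // cap the algorithms tried
  //      at::globalContext().setDeterministicCuDNN(false);
  //    }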
  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false) (and analogously
  // setSDPUseMemEfficient(false) for the memory-efficient kernel).
  // This is useful for debugging purposes. For example, if you want to
  // compare the performance of the flash SDP kernels against the unfused
  // kernel, you can disable the flash SDP kernels. Conversely, by disabling
  // the math SDP kernel, you can force your code to use the flash kernels.
  // The math SDP kernel can be disabled by calling
  // at::globalContext().setSDPUseMath(false).
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;

  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;

  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;
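  // A minimal sketch of the debugging workflow described in
  // Note [Disabling Fused SDP Kernels]: disable the fused kernels so that
  // scaled dot product attention falls back to the unfused math path
  // (the function name is illustrative only):
  //
  //    void compare_against_math_sdp_example() {
  //      at::globalContext().setSDPUseFlash(false);
  //      at::globalContext().setSDPUseMemEfficient(false);
  //      at::globalContext().setSDPUseMath(true);
  //      // ... run and time the attention call of interest here ...
  //    }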
  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //    void example_func() {
  //      // See Note [Enabling Deterministic Operations]
  //      if (at::globalContext().deterministicAlgorithms()) {
  //        example_func_deterministic();
  //      } else {
  //        example_func_nondeterministic();
  //      }
  //    }
  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool, bool);
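  // A minimal sketch of flipping the global deterministic-algorithms switch;
  // the second argument is assumed here to select warn-only mode, matching
  // deterministicAlgorithmsWarnOnly() above (function name illustrative only):
  //
  //    void enable_determinism_example() {
  //      // error out (rather than only warn) when a nondeterministic op is hit
  //      at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);
  //      bool on = at::globalContext().deterministicAlgorithms();  // now true
  //    }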
  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any
  //   new tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and error-throwing
  // code for a nondeterministic operation:
  //
  //    void example_func() {
  //      // See Note [Writing Nondeterministic Operations]
  //      // Nondeterministic because <reason>
  //      at::globalContext().alertNotDeterministic("example_func");
  //      ...
  //    }
  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);
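  // A minimal sketch of relaxing float32 matmul precision via the setters
  // above; "high" is assumed here to map to Float32MatmulPrecision::HIGH,
  // which permits TF32 math in cuBLAS (function name illustrative only):
  //
  //    void relax_fp32_precision_example() {
  //      at::globalContext().setFloat32MatmulPrecision("high");  // allow TF32 matmuls
  //      at::globalContext().setAllowTF32CuDNN(true);            // TF32 in cuDNN convolutions
  //      at::globalContext().setAllowFP16ReductionCuBLAS(true);  // reduced-precision reductions
  //    }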
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();
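  // A minimal sketch of selecting a quantization engine at runtime, only if
  // it is among the engines this build supports (requires <algorithm> for
  // std::find; function name illustrative only):
  //
  //    void prefer_qengine_example(at::QEngine wanted) {
  //      const auto& engines = at::Context::supportedQEngines();
  //      if (std::find(engines.begin(), engines.end(), wanted) != engines.end()) {
  //        at::globalContext().setQEngine(wanted);
  //      }
  //    }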
  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;

  // This method is used to release the original weight after pre-packing.
  // It should be called once before loading/running the model.
  // NB: By default it is set to true for mobile builds.
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();

 private:
  void initCUDAIfNeeded(DeviceType p) {
    if (p == DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(DeviceType p) {
    if (p == DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  c10::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;

  Allocator* prev_allocator_ptr_{nullptr};
};
TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}
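// For illustration: init() simply forces the global Context singleton to be
// constructed; every later globalContext() call returns the same object, so
// settings made through it persist process-wide. For example:
//
//    at::init();
//    at::globalContext().setUserEnabledMkldnn(true);
//    bool mkldnn_on = at::globalContext().userEnabledMkldnn();  // true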
TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}

static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasORT() {
  return globalContext().hasORT();
}
// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION. If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count()).
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP "
        "masquerades as CUDA (e.g., when you say CUDA, on a HIP build of "
        "ATen, this actually means HIP). Rebuild PyTorch with one or the "
        "other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}
static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}

static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && num_gpus > 0) {
    for (const auto i : c10::irange(num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}
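// For illustration, a typical use of manual_seed: a single call seeds the CPU
// generator and, when devices are present, every CUDA generator as well as the
// MPS generator, so subsequent random ops start from the same seed per backend:
//
//    at::manual_seed(42);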
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. For such cases, this class implements a RAII
// guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
//     NoTF32Guard disable_tf32;
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};
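// A slightly fuller usage sketch: being a RAII guard, NoTF32Guard disables
// TF32 for the enclosing scope and is assumed here to restore the prior
// behavior on destruction (function name illustrative only):
//
//    void precision_sensitive_op_example() {
//      at::NoTF32Guard disable_tf32;  // TF32 disabled until end of scope
//      // ... call into cuBLAS/cuDNN-backed kernels here ...
//    }  // guard destroyed; previous TF32 behavior applies again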
#ifdef USE_ROCM
struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();

 private:
  static thread_local bool is_backward_pass_;
};
#endif

} // namespace at