autocast_mode.h

#pragma once

#include <ATen/ATen.h>

namespace at {
namespace autocast {

TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xpu_enabled();
TORCH_API void set_xpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xpu_dtype();
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
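// Illustrative sketch: a minimal RAII-style guard built on the state API
// above, roughly mirroring what a Python-level `torch.autocast` region does
// on enter/exit. `ExampleAutocastGuard` is a hypothetical helper, not part of
// the autocast API, and it assumes decrement_nesting() returns the remaining
// nesting depth (as the Python-level exit path does).
struct ExampleAutocastGuard {
  ExampleAutocastGuard() : prev_enabled_(is_enabled()) {
    set_enabled(true);
    increment_nesting();
  }
  ~ExampleAutocastGuard() {
    // Drop the cast cache once the outermost autocast region exits, then
    // restore the previous enabled state.
    if (decrement_nesting() == 0) {
      clear_cache();
    }
    set_enabled(prev_enabled_);
  }
  bool prev_enabled_;
};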
namespace {
bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return (tensor.is_cuda() || tensor.is_xla()) &&
          tensor.is_floating_point();
    case DeviceType::CPU:
      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
          tensor.is_floating_point();
    case DeviceType::XPU:
      return tensor.is_xpu() && tensor.is_floating_point();
    case DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    default:
      return false;
  }
}
} // namespace

inline DispatchKey get_autocast_dispatch_key_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return DispatchKey::Autocast;
    case DeviceType::CPU:
      return DispatchKey::AutocastCPU;
    case DeviceType::XPU:
      return DispatchKey::AutocastXPU;
    case DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
    DeviceType device_type) {
  switch (device_type) {
    case DeviceType::CUDA:
      return get_autocast_gpu_dtype();
    case DeviceType::CPU:
      return get_autocast_cpu_dtype();
    case DeviceType::XPU:
      return get_autocast_xpu_dtype();
    case DeviceType::HPU:
      return get_autocast_hpu_dtype();
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  if (current == at::kDouble) {
    AT_ERROR("promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}

// Overload to catch TensorList args (for e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    DeviceType device_type = DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    DeviceType device_type = DeviceType::CUDA) {
  return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
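// Illustrative sketch: how promote_type is meant to be used for
// multiple-input ops (such as cat) that should run in the widest eligible
// input type. `example_promote_for_cat` is a hypothetical helper, not part
// of the autocast API, shown only to demonstrate the call pattern.
inline at::ScalarType example_promote_for_cat(const TensorList& tensors) {
  const auto device_type = DeviceType::CUDA;
  // Start from the device's lower-precision autocast dtype; any eligible
  // float32 input widens the result to float32, while double tensors are
  // simply ignored by prioritize().
  return promote_type(
      get_lower_precision_fp_from_device_type(device_type),
      device_type,
      tensors);
}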
/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    DeviceType device_type = DeviceType::CUDA);

// Overload to process optional<Tensor>
inline c10::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const c10::optional<Tensor>& arg,
    DeviceType device_type = DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return c10::nullopt;
  }
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    DeviceType device_type = DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    DeviceType device_type = DeviceType::CUDA) {
  return arg;
}
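// Illustrative sketch: the typical call pattern for cached_cast inside an
// autocast wrapper for a lower-precision-friendly op such as mm.
// `example_autocast_mm` is a hypothetical wrapper, not part of the autocast
// API; the real wrappers also exclude the autocast dispatch key before
// redispatching.
inline Tensor example_autocast_mm(const Tensor& self, const Tensor& other) {
  const auto device_type = DeviceType::CUDA;
  const auto to_type = get_lower_precision_fp_from_device_type(device_type);
  // Eligible float32 inputs are cast down (casts of parameter-like tensors
  // may be cached and reused); other inputs pass through unchanged.
  return at::mm(
      cached_cast(to_type, self, device_type),
      cached_cast(to_type, other, device_type));
}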
/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/
// Overload to catch dtype flags
c10::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const c10::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}

// Checks whether the first Tensor argument is eligible for autocasting; the
// trailing args are accepted so callers can forward whole argument packs.
template <typename... Args>
inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
  return is_eligible(arg);
}

// Returns to_type if the first Tensor argument is autocast-eligible,
// otherwise that argument's own scalar_type.
template <typename... Args>
inline at::ScalarType type_from_firstarg(
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg) ? to_type : arg.scalar_type());
}
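// Illustrative sketch: how the dtype-flag helpers combine for an op that
// should run in float32 unless the caller asked for a specific dtype
// (a softmax-style policy). `example_autocast_softmax` is a hypothetical
// wrapper, not part of the autocast API.
inline Tensor example_autocast_softmax(
    const Tensor& self,
    int64_t dim,
    c10::optional<ScalarType> dtype) {
  if (firstarg_is_eligible(self)) {
    // Eligible input: run in float32 unless the caller explicitly requested
    // a dtype, which set_opt_dtype leaves untouched.
    return at::softmax(self, dim, set_opt_dtype(at::kFloat, dtype));
  }
  // Ineligible input (e.g. double, or undefined): leave the call as-is.
  return at::softmax(self, dim, dtype);
}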
} // namespace autocast
} // namespace at