- #pragma once
- #include <ATen/ATen.h>
- namespace at {
- namespace autocast {
- TORCH_API bool is_enabled();
- TORCH_API void set_enabled(bool enabled);
- TORCH_API void clear_cache();
- TORCH_API int increment_nesting();
- TORCH_API int decrement_nesting();
- TORCH_API bool is_cpu_enabled();
- TORCH_API void set_cpu_enabled(bool enabled);
- TORCH_API at::ScalarType get_autocast_gpu_dtype();
- TORCH_API at::ScalarType get_autocast_cpu_dtype();
- TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
- TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
- TORCH_API bool is_xpu_enabled();
- TORCH_API void set_xpu_enabled(bool enabled);
- TORCH_API at::ScalarType get_autocast_xpu_dtype();
- TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
- TORCH_API bool is_hpu_enabled();
- TORCH_API void set_hpu_enabled(bool enabled);
- TORCH_API at::ScalarType get_autocast_hpu_dtype();
- TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
- TORCH_API bool is_autocast_cache_enabled();
- TORCH_API void set_autocast_cache_enabled(bool enabled);
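- // Hedged usage sketch (illustrative only, not part of this header): the Python
- // torch.autocast context manager drives this state roughly as follows.
- //   at::autocast::set_autocast_gpu_dtype(at::kHalf); // choose the low-precision dtype
- //   at::autocast::set_enabled(true);                 // enter the autocast region
- //   ... run ops that dispatch through the Autocast key ...
- //   at::autocast::set_enabled(false);                // leave the region
- //   at::autocast::clear_cache();                     // drop the cached weight casts
- // increment_nesting()/decrement_nesting() track nested regions; the cache is
- // typically cleared when the outermost region exits.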
- namespace {
- bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
- switch (device_type) {
- case DeviceType::CUDA:
- return (tensor.is_cuda() || tensor.is_xla()) &&
- tensor.is_floating_point();
- case DeviceType::CPU:
- return (tensor.is_cpu() || tensor.is_mkldnn()) &&
- tensor.is_floating_point();
- case DeviceType::XPU:
- return tensor.is_xpu() && tensor.is_floating_point();
- case DeviceType::HPU:
- return tensor.is_hpu() && tensor.is_floating_point();
- default:
- return false;
- }
- }
- } // namespace
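- // For example (illustration only): under DeviceType::CUDA a float16 CUDA tensor
- // is autocast-eligible, while an integer tensor or a plain CPU tensor is not.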
- inline DispatchKey get_autocast_dispatch_key_from_device_type(
- DeviceType device_type) {
- switch (device_type) {
- case DeviceType::CUDA:
- return DispatchKey::Autocast;
- case DeviceType::CPU:
- return DispatchKey::AutocastCPU;
- case DeviceType::XPU:
- return DispatchKey::AutocastXPU;
- case DeviceType::HPU:
- return DispatchKey::AutocastHPU;
- default:
- throw std::runtime_error(
- "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
- }
- }
- inline at::ScalarType get_lower_precision_fp_from_device_type(
- DeviceType device_type) {
- switch (device_type) {
- case DeviceType::CUDA:
- return get_autocast_gpu_dtype();
- case DeviceType::CPU:
- return get_autocast_cpu_dtype();
- case DeviceType::XPU:
- return get_autocast_xpu_dtype();
- case DeviceType::HPU:
- return get_autocast_hpu_dtype();
- default:
- throw std::runtime_error(
- "unknown device type for autocast in get_lower_precision_fp_from_device_type");
- }
- }
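- // Note: the defaults behind these getters live in the implementation, not this
- // header; out of the box CUDA uses float16 and CPU uses bfloat16 as the
- // lower-precision dtype, and both can be changed via the setters declared above.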
- /********************************************************************
- Logic to extract the promote type from any Tensor or TensorList args.
- ********************************************************************/
- // Overload to catch Tensor args.
- // If nextArg is floating-point, compare its scalar_type with our
- // current best guess for the promote type, and update if necessary.
- inline at::ScalarType prioritize(
- at::ScalarType current,
- const Tensor& nextArg,
- DeviceType device_type = DeviceType::CUDA) {
- if (current == at::kDouble) {
- AT_ERROR("promote type is double in at::autocast::prioritize");
- return current;
- }
- at::ScalarType lower_precision_fp =
- get_lower_precision_fp_from_device_type(device_type);
- if (is_autocast_eligible(nextArg, device_type)) {
- auto next = nextArg.scalar_type();
- if (next == at::kDouble) {
- return current; // ignores double tensors
- } else if (current == at::kFloat || next == at::kFloat) {
- return at::kFloat; // prioritizes float over lower_precision_fp
- } else if (current == lower_precision_fp && next == lower_precision_fp) {
- return lower_precision_fp;
- } else {
- AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
- return current;
- }
- } else {
- return current;
- }
- }
- // Overload to catch TensorList args (e.g. cat, stack).
- // Reuses the overload above to process each Tensor in the list.
- inline at::ScalarType prioritize(
- at::ScalarType current,
- const TensorList& list,
- DeviceType device_type = DeviceType::CUDA) {
- for (const auto& tensor : list) {
- current = prioritize(current, tensor, device_type);
- }
- return current;
- }
- inline at::ScalarType prioritize(
- at::ScalarType current,
- const ITensorListRef& list,
- DeviceType device_type = DeviceType::CUDA) {
- for (const auto& tensor : list) {
- current = prioritize(current, tensor, device_type);
- }
- return current;
- }
- // Template to catch non-Tensor args (no-op that returns current best guess)
- template <typename T>
- inline at::ScalarType prioritize(
- at::ScalarType current,
- T nextArg,
- DeviceType device_type = DeviceType::CUDA) {
- return current;
- }
- // Overload for the tail case.
- inline at::ScalarType promote_type(
- at::ScalarType current,
- DeviceType device_type) {
- return current;
- }
- // Unpack args and determine if incoming lower_precision_fp tensors need to be
- // promoted to float32. Non-Tensor arguments are ignored.
- template <typename Arg0, typename... Args>
- inline at::ScalarType promote_type(
- at::ScalarType current,
- DeviceType device_type,
- Arg0 arg0,
- Args... args) {
- auto new_current = prioritize(current, arg0, device_type);
- return promote_type(new_current, device_type, args...);
- }
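- // Hedged usage sketch (argument names are illustrative, not from this file): a
- // wrapper for a multi-input op such as addmm could compute its common dtype as
- //   auto to_type = promote_type(
- //       get_lower_precision_fp_from_device_type(device_type), device_type,
- //       self, mat1, mat2);
- // so a single float32 input promotes the whole call to float32, while all
- // lower-precision inputs keep the lower-precision dtype.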
- /****************************************************
- Logic to apply cached casting to any Tensor argument.
- ****************************************************/
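- // A Tensor is eligible for cached casting if it is defined, autocast-eligible
- // for the given device type, and not double precision.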
- inline bool is_eligible(
- const Tensor& arg,
- DeviceType device_type = DeviceType::CUDA) {
- return (
- arg.defined() && is_autocast_eligible(arg, device_type) &&
- (arg.scalar_type() != at::kDouble));
- }
- // Overload to catch Tensor args
- TORCH_API Tensor cached_cast(
- at::ScalarType to_type,
- const Tensor& arg,
- DeviceType device_type = DeviceType::CUDA);
- // Overload to process optional<Tensor>
- inline c10::optional<Tensor> cached_cast(
- at::ScalarType to_type,
- const c10::optional<Tensor>& arg,
- DeviceType device_type = DeviceType::CUDA) {
- if (arg.has_value()) {
- return cached_cast(to_type, *arg, device_type);
- } else {
- return c10::nullopt;
- }
- }
- // Overload to process TensorLists
- inline std::vector<Tensor> cached_cast(
- at::ScalarType to_type,
- const TensorList& arg,
- DeviceType device_type = DeviceType::CUDA) {
- std::vector<Tensor> vec;
- vec.reserve(arg.size());
- for (const auto& t : arg) {
- vec.emplace_back(cached_cast(to_type, t, device_type));
- }
- return vec;
- }
- inline std::vector<Tensor> cached_cast(
- at::ScalarType to_type,
- const ITensorListRef& arg,
- DeviceType device_type = DeviceType::CUDA) {
- std::vector<Tensor> vec;
- vec.reserve(arg.size());
- for (const auto& t : arg) {
- vec.emplace_back(cached_cast(to_type, t, device_type));
- }
- return vec;
- }
- // Template to catch non-Tensor args.
- template <typename T>
- inline T cached_cast(
- at::ScalarType to_type,
- T arg,
- DeviceType device_type = DeviceType::CUDA) {
- return arg;
- }
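- // Hedged sketch (not a registration from this file): a wrapper that runs an op
- // in the lower-precision dtype would cast each Tensor argument, e.g.
- //   auto out = at::matmul(
- //       cached_cast(at::kHalf, a, DeviceType::CUDA),
- //       cached_cast(at::kHalf, b, DeviceType::CUDA));
- // Non-Tensor arguments fall through unchanged via the template overload above.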
- /*******************************************************
- Logic to flip an output dtype flag.
- Keep it simple for now by assuming only one such flag is
- present in the argument list. If I ever need a function
- with more than one flag I'll figure out something else.
- The policy is:
- If the user has explicitly specified a dtype, respect it.
- Otherwise, set it to the autocast type.
- ********************************************************/
- // Overload to catch dtype flags
- c10::optional<ScalarType> inline set_opt_dtype(
- at::ScalarType to_type,
- const c10::optional<ScalarType>& dtype) {
- return dtype.has_value() ? dtype : to_type;
- }
- // Template to catch other args
- template <typename T>
- inline T set_opt_dtype(at::ScalarType to_type, T arg) {
- return arg;
- }
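- // Hedged example: for an op with an optional dtype argument (softmax-style),
- // forwarding set_opt_dtype(at::kFloat, dtype) keeps a user-specified dtype and
- // falls back to float32 only when none was given.
- // The helpers below look at only the first Tensor argument: firstarg_is_eligible
- // reports whether it can be autocast, and type_from_firstarg returns to_type if
- // it can, or that argument's own scalar type otherwise.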
- template <typename... Args>
- inline bool firstarg_is_eligible(const Tensor& arg, Args... args) {
- return is_eligible(arg);
- }
- template <typename... Args>
- inline at::ScalarType type_from_firstarg(
- at::ScalarType to_type,
- const Tensor& arg,
- Args... args) {
- return (is_eligible(arg) ? to_type : arg.scalar_type());
- }
- } // namespace autocast
- } // namespace at