MathConstants.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <c10/util/BFloat16.h>
  4. #include <c10/util/Half.h>
  5. C10_CLANG_DIAGNOSTIC_PUSH()
  6. #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
  7. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
  8. #endif
  9. namespace c10 {
  10. // TODO: Replace me with inline constexpr variable when C++17 becomes available
  11. namespace detail {
  12. template <typename T>
  13. C10_HOST_DEVICE inline constexpr T e() {
  14. return static_cast<T>(2.718281828459045235360287471352662);
  15. }
  16. template <typename T>
  17. C10_HOST_DEVICE inline constexpr T euler() {
  18. return static_cast<T>(0.577215664901532860606512090082402);
  19. }
  20. template <typename T>
  21. C10_HOST_DEVICE inline constexpr T frac_1_pi() {
  22. return static_cast<T>(0.318309886183790671537767526745028);
  23. }
  24. template <typename T>
  25. C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() {
  26. return static_cast<T>(0.564189583547756286948079451560772);
  27. }
  28. template <typename T>
  29. C10_HOST_DEVICE inline constexpr T frac_sqrt_2() {
  30. return static_cast<T>(0.707106781186547524400844362104849);
  31. }
  32. template <typename T>
  33. C10_HOST_DEVICE inline constexpr T frac_sqrt_3() {
  34. return static_cast<T>(0.577350269189625764509148780501957);
  35. }
  36. template <typename T>
  37. C10_HOST_DEVICE inline constexpr T golden_ratio() {
  38. return static_cast<T>(1.618033988749894848204586834365638);
  39. }
  40. template <typename T>
  41. C10_HOST_DEVICE inline constexpr T ln_10() {
  42. return static_cast<T>(2.302585092994045684017991454684364);
  43. }
  44. template <typename T>
  45. C10_HOST_DEVICE inline constexpr T ln_2() {
  46. return static_cast<T>(0.693147180559945309417232121458176);
  47. }
  48. template <typename T>
  49. C10_HOST_DEVICE inline constexpr T log_10_e() {
  50. return static_cast<T>(0.434294481903251827651128918916605);
  51. }
  52. template <typename T>
  53. C10_HOST_DEVICE inline constexpr T log_2_e() {
  54. return static_cast<T>(1.442695040888963407359924681001892);
  55. }
  56. template <typename T>
  57. C10_HOST_DEVICE inline constexpr T pi() {
  58. return static_cast<T>(3.141592653589793238462643383279502);
  59. }
  60. template <typename T>
  61. C10_HOST_DEVICE inline constexpr T sqrt_2() {
  62. return static_cast<T>(1.414213562373095048801688724209698);
  63. }
  64. template <typename T>
  65. C10_HOST_DEVICE inline constexpr T sqrt_3() {
  66. return static_cast<T>(1.732050807568877293527446341505872);
  67. }
  68. template <>
  69. C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() {
  70. // According to
  71. // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values
  72. // pi is encoded as 4049
  73. return BFloat16(0x4049, BFloat16::from_bits());
  74. }
  75. template <>
  76. C10_HOST_DEVICE inline constexpr Half pi<Half>() {
  77. return Half(0x4248, Half::from_bits());
  78. }
  79. } // namespace detail
  80. template <typename T>
  81. constexpr T e = c10::detail::e<T>();
  82. template <typename T>
  83. constexpr T euler = c10::detail::euler<T>();
  84. template <typename T>
  85. constexpr T frac_1_pi = c10::detail::frac_1_pi<T>();
  86. template <typename T>
  87. constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>();
  88. template <typename T>
  89. constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>();
  90. template <typename T>
  91. constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>();
  92. template <typename T>
  93. constexpr T golden_ratio = c10::detail::golden_ratio<T>();
  94. template <typename T>
  95. constexpr T ln_10 = c10::detail::ln_10<T>();
  96. template <typename T>
  97. constexpr T ln_2 = c10::detail::ln_2<T>();
  98. template <typename T>
  99. constexpr T log_10_e = c10::detail::log_10_e<T>();
  100. template <typename T>
  101. constexpr T log_2_e = c10::detail::log_2_e<T>();
  102. template <typename T>
  103. constexpr T pi = c10::detail::pi<T>();
  104. template <typename T>
  105. constexpr T sqrt_2 = c10::detail::sqrt_2<T>();
  106. template <typename T>
  107. constexpr T sqrt_3 = c10::detail::sqrt_3<T>();
  108. } // namespace c10
  109. C10_CLANG_DIAGNOSTIC_POP()