BFloat16-inl.h (10 KB)

  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <limits>
  4. C10_CLANG_DIAGNOSTIC_PUSH()
  5. #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
  6. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
  7. #endif
  8. #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  9. #if defined(CL_SYCL_LANGUAGE_VERSION)
  10. #include <CL/sycl.hpp> // for SYCL 1.2.1
  11. #else
  12. #include <sycl/sycl.hpp> // for SYCL 2020
  13. #endif
  14. #include <ext/oneapi/bfloat16.hpp>
  15. #endif
  16. namespace c10 {
  17. /// Constructors
  18. inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
  19. :
  20. #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
  21. __CUDA_ARCH__ >= 800
  22. x(__bfloat16_as_ushort(__float2bfloat16(value)))
  23. #elif defined(__SYCL_DEVICE_ONLY__) && \
  24. defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  25. x(sycl::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
  26. #else
  27. // RNE by default
  28. x(detail::round_to_nearest_even(value))
  29. #endif
  30. {
  31. }
  32. /// Implicit conversions
  33. inline C10_HOST_DEVICE BFloat16::operator float() const {
  34. #if defined(__CUDACC__) && !defined(USE_ROCM)
  35. return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
  36. #elif defined(__SYCL_DEVICE_ONLY__) && \
  37. defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  38. return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
  39. #else
  40. return detail::f32_from_bits(x);
  41. #endif
  42. }
  #if defined(__CUDACC__) && !defined(USE_ROCM)
  /// Interop with CUDA's __nv_bfloat16. Both directions copy the stored bit
  /// pattern unchanged by reinterpreting the object's address.
  inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value)
      : x(*reinterpret_cast<const unsigned short*>(&value)) {}

  inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
    const __nv_bfloat16* as_cuda = reinterpret_cast<const __nv_bfloat16*>(&x);
    return *as_cuda;
  }
  #endif
  #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  /// Interop with sycl::ext::oneapi::bfloat16. Both directions copy the stored
  /// bit pattern unchanged by reinterpreting the object's address.
  inline C10_HOST_DEVICE BFloat16::BFloat16(
      const sycl::ext::oneapi::bfloat16& value)
      : x(*reinterpret_cast<const unsigned short*>(&value)) {}

  inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
    using sycl_bf16 = sycl::ext::oneapi::bfloat16;
    const sycl_bf16* as_sycl = reinterpret_cast<const sycl_bf16*>(&x);
    return *as_sycl;
  }
  #endif
  // CUDA intrinsics
  #if defined(__CUDACC__) || defined(__HIPCC__)
  /// BFloat16 overload of CUDA's __ldg. It forwards to the __nv_bfloat16
  /// overload only for CUDA (not ROCm) device code with __CUDA_ARCH__ >= 800.
  /// Every other configuration falls back to an ordinary dereference.
  inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
  #if defined(USE_ROCM) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
    return *ptr;
  #else
    return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
  #endif
  }
  #endif
  70. /// Arithmetic
  71. inline C10_HOST_DEVICE BFloat16
  72. operator+(const BFloat16& a, const BFloat16& b) {
  73. return static_cast<float>(a) + static_cast<float>(b);
  74. }
  75. inline C10_HOST_DEVICE BFloat16
  76. operator-(const BFloat16& a, const BFloat16& b) {
  77. return static_cast<float>(a) - static_cast<float>(b);
  78. }
  79. inline C10_HOST_DEVICE BFloat16
  80. operator*(const BFloat16& a, const BFloat16& b) {
  81. return static_cast<float>(a) * static_cast<float>(b);
  82. }
  83. inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
  84. __ubsan_ignore_float_divide_by_zero__ {
  85. return static_cast<float>(a) / static_cast<float>(b);
  86. }
  87. inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
  88. return -static_cast<float>(a);
  89. }
  90. inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
  91. a = a + b;
  92. return a;
  93. }
  94. inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
  95. a = a - b;
  96. return a;
  97. }
  98. inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
  99. a = a * b;
  100. return a;
  101. }
  102. inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
  103. a = a / b;
  104. return a;
  105. }
  106. inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
  107. a.x = a.x | b.x;
  108. return a;
  109. }
  110. inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
  111. a.x = a.x ^ b.x;
  112. return a;
  113. }
  114. inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
  115. a.x = a.x & b.x;
  116. return a;
  117. }
  118. /// Arithmetic with floats
  119. inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
  120. return static_cast<float>(a) + b;
  121. }
  122. inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
  123. return static_cast<float>(a) - b;
  124. }
  125. inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
  126. return static_cast<float>(a) * b;
  127. }
  128. inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
  129. return static_cast<float>(a) / b;
  130. }
  131. inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
  132. return a + static_cast<float>(b);
  133. }
  134. inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
  135. return a - static_cast<float>(b);
  136. }
  137. inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
  138. return a * static_cast<float>(b);
  139. }
  140. inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
  141. return a / static_cast<float>(b);
  142. }
  143. inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
  144. return a += static_cast<float>(b);
  145. }
  146. inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
  147. return a -= static_cast<float>(b);
  148. }
  149. inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
  150. return a *= static_cast<float>(b);
  151. }
  152. inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
  153. return a /= static_cast<float>(b);
  154. }
  155. /// Arithmetic with doubles
  156. inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
  157. return static_cast<double>(a) + b;
  158. }
  159. inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
  160. return static_cast<double>(a) - b;
  161. }
  162. inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
  163. return static_cast<double>(a) * b;
  164. }
  165. inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
  166. return static_cast<double>(a) / b;
  167. }
  168. inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
  169. return a + static_cast<double>(b);
  170. }
  171. inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
  172. return a - static_cast<double>(b);
  173. }
  174. inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
  175. return a * static_cast<double>(b);
  176. }
  177. inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
  178. return a / static_cast<double>(b);
  179. }
  180. /// Arithmetic with ints
  181. inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
  182. return a + static_cast<BFloat16>(b);
  183. }
  184. inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
  185. return a - static_cast<BFloat16>(b);
  186. }
  187. inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
  188. return a * static_cast<BFloat16>(b);
  189. }
  190. inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
  191. return a / static_cast<BFloat16>(b);
  192. }
  193. inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
  194. return static_cast<BFloat16>(a) + b;
  195. }
  196. inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
  197. return static_cast<BFloat16>(a) - b;
  198. }
  199. inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
  200. return static_cast<BFloat16>(a) * b;
  201. }
  202. inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
  203. return static_cast<BFloat16>(a) / b;
  204. }
  205. //// Arithmetic with int64_t
  206. inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
  207. return a + static_cast<BFloat16>(b);
  208. }
  209. inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
  210. return a - static_cast<BFloat16>(b);
  211. }
  212. inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
  213. return a * static_cast<BFloat16>(b);
  214. }
  215. inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
  216. return a / static_cast<BFloat16>(b);
  217. }
  218. inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
  219. return static_cast<BFloat16>(a) + b;
  220. }
  221. inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
  222. return static_cast<BFloat16>(a) - b;
  223. }
  224. inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
  225. return static_cast<BFloat16>(a) * b;
  226. }
  227. inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
  228. return static_cast<BFloat16>(a) / b;
  229. }
  230. // Overloading < and > operators, because std::max and std::min use them.
  231. inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
  232. return float(lhs) > float(rhs);
  233. }
  234. inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
  235. return float(lhs) < float(rhs);
  236. }
  237. } // namespace c10
  238. namespace std {
  239. template <>
  240. class numeric_limits<c10::BFloat16> {
  241. public:
  242. static constexpr bool is_signed = true;
  243. static constexpr bool is_specialized = true;
  244. static constexpr bool is_integer = false;
  245. static constexpr bool is_exact = false;
  246. static constexpr bool has_infinity = true;
  247. static constexpr bool has_quiet_NaN = true;
  248. static constexpr bool has_signaling_NaN = true;
  249. static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
  250. static constexpr auto has_denorm_loss =
  251. numeric_limits<float>::has_denorm_loss;
  252. static constexpr auto round_style = numeric_limits<float>::round_style;
  253. static constexpr bool is_iec559 = false;
  254. static constexpr bool is_bounded = true;
  255. static constexpr bool is_modulo = false;
  256. static constexpr int digits = 8;
  257. static constexpr int digits10 = 2;
  258. static constexpr int max_digits10 = 4;
  259. static constexpr int radix = 2;
  260. static constexpr int min_exponent = -125;
  261. static constexpr int min_exponent10 = -37;
  262. static constexpr int max_exponent = 128;
  263. static constexpr int max_exponent10 = 38;
  264. static constexpr auto traps = numeric_limits<float>::traps;
  265. static constexpr auto tinyness_before =
  266. numeric_limits<float>::tinyness_before;
  267. static constexpr c10::BFloat16 min() {
  268. return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
  269. }
  270. static constexpr c10::BFloat16 lowest() {
  271. return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
  272. }
  273. static constexpr c10::BFloat16 max() {
  274. return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
  275. }
  276. static constexpr c10::BFloat16 epsilon() {
  277. return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
  278. }
  279. static constexpr c10::BFloat16 round_error() {
  280. return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
  281. }
  282. static constexpr c10::BFloat16 infinity() {
  283. return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
  284. }
  285. static constexpr c10::BFloat16 quiet_NaN() {
  286. return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
  287. }
  288. static constexpr c10::BFloat16 signaling_NaN() {
  289. return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
  290. }
  291. static constexpr c10::BFloat16 denorm_min() {
  292. return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
  293. }
  294. };
  295. } // namespace std
  296. C10_CLANG_DIAGNOSTIC_POP()