// onnxruntime_float16.h
  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. #pragma once
  4. #include <stdint.h>
  5. #include <cmath>
  6. #include <cstring>
  7. #include <limits>
  8. namespace onnxruntime_float16 {
  namespace detail {

  /// <summary>
  /// Compile-time byte-order tags. `native` always aliases either `little` or `big`.
  /// On Windows the native order is hard-wired to little-endian; on GCC/clang the
  /// enumerator values are taken from the compiler's predefined byte-order macros.
  /// Any other toolchain is rejected at preprocessing time.
  /// </summary>
  enum class endian {
  #ifdef _WIN32
    little = 0,
    big = 1,
    native = little,
  #elif defined(__GNUC__) || defined(__clang__)
    little = __ORDER_LITTLE_ENDIAN__,
    big = __ORDER_BIG_ENDIAN__,
    native = __BYTE_ORDER__,
  #else
  #error onnxruntime_float16::detail::endian is not implemented in this environment.
  #endif
  };

  // Mixed/PDP byte orders would make `native` differ from both tags; refuse to build there.
  static_assert(
      !(endian::native != endian::little && endian::native != endian::big),
      "Only little-endian or big-endian native byte orders are supported.");

  }  // namespace detail
  /// <summary>
  /// Shared implementation between public and internal classes. CRTP pattern.
  /// Stores the raw 16-bit pattern of a half-precision value (sign bit 0x8000,
  /// exponent field 0x7C00, remaining 10 bits of mantissa) and implements
  /// classification and IEEE-style comparison directly on those bits.
  /// Derived must provide a static FromBits(uint16_t) factory; Abs() and Negate() call it.
  /// </summary>
  template <class Derived>
  struct Float16Impl {
   protected:
    /// <summary>
    /// Converts from float to uint16_t float16 representation
    /// </summary>
    /// <param name="v">value to convert</param>
    /// <returns>float16 bit pattern</returns>
    constexpr static uint16_t ToUint16Impl(float v) noexcept;

    /// <summary>
    /// Converts float16 to float
    /// </summary>
    /// <returns>float representation of float16 value</returns>
    float ToFloatImpl() const noexcept;

    /// <summary>
    /// Computes the bit pattern of the absolute value by clearing the sign bit.
    /// </summary>
    /// <returns>Bits of the absolute value</returns>
    uint16_t AbsImpl() const noexcept {
      return static_cast<uint16_t>(val & ~kSignMask);
    }

    /// <summary>
    /// Computes the bit pattern with the sign flipped. NaN is returned unchanged.
    /// </summary>
    /// <returns>Bits of the negated value</returns>
    uint16_t NegateImpl() const noexcept {
      return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
    }

   public:
    // uint16_t special values
    static constexpr uint16_t kSignMask = 0x8000U;
    static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
    static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
    static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
    static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
    static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
    // NOTE(review): 0x4170 decodes to 2.71875 (roughly e), not the half-precision machine
    // epsilon 2^-10 (0x1400). Value kept as-is; confirm the intended meaning with callers.
    static constexpr uint16_t kEpsilonBits = 0x4170U;
    static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Lowest (most negative) finite value, -65504
    static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest finite value, 65504
    static constexpr uint16_t kOneBits = 0x3C00U;       // 1.0
    static constexpr uint16_t kMinusOneBits = 0xBC00U;  // -1.0

    // Raw bit pattern; zero-initialized, i.e. +0.0.
    uint16_t val{0};

    Float16Impl() = default;

    /// <summary>
    /// Checks if the value is negative
    /// </summary>
    /// <returns>true if the sign bit is set (this includes -0 and negative NaN)</returns>
    bool IsNegative() const noexcept {
      return static_cast<int16_t>(val) < 0;
    }

    /// <summary>
    /// Tests if the value is NaN
    /// </summary>
    /// <returns>true if NaN</returns>
    bool IsNaN() const noexcept {
      // All-ones exponent with a non-zero mantissa sorts strictly above infinity.
      return AbsImpl() > kPositiveInfinityBits;
    }

    /// <summary>
    /// Tests if the value is finite
    /// </summary>
    /// <returns>true if finite</returns>
    bool IsFinite() const noexcept {
      return AbsImpl() < kPositiveInfinityBits;
    }

    /// <summary>
    /// Tests if the value represents positive infinity.
    /// </summary>
    /// <returns>true if positive infinity</returns>
    bool IsPositiveInfinity() const noexcept {
      return val == kPositiveInfinityBits;
    }

    /// <summary>
    /// Tests if the value represents negative infinity
    /// </summary>
    /// <returns>true if negative infinity</returns>
    bool IsNegativeInfinity() const noexcept {
      return val == kNegativeInfinityBits;
    }

    /// <summary>
    /// Tests if the value is either positive or negative infinity.
    /// </summary>
    /// <returns>True if absolute value is infinity</returns>
    bool IsInfinity() const noexcept {
      return AbsImpl() == kPositiveInfinityBits;
    }

    /// <summary>
    /// Tests if the value is NaN or zero. Useful for comparisons.
    /// </summary>
    /// <returns>True if NaN or zero.</returns>
    bool IsNaNOrZero() const noexcept {
      auto abs = AbsImpl();
      return (abs == 0 || abs > kPositiveInfinityBits);
    }

    /// <summary>
    /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
    /// </summary>
    /// <returns>True if so</returns>
    bool IsNormal() const noexcept {
      auto abs = AbsImpl();
      return (abs < kPositiveInfinityBits)           // is finite
             && (abs != 0)                           // is not zero
             && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
    }

    /// <summary>
    /// Tests if the value is subnormal (denormal).
    /// </summary>
    /// <returns>True if so</returns>
    bool IsSubnormal() const noexcept {
      auto abs = AbsImpl();
      return (abs < kPositiveInfinityBits)           // is finite
             && (abs != 0)                           // is not zero
             && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
    }

    /// <summary>
    /// Creates an instance that represents absolute value.
    /// </summary>
    /// <returns>Absolute value</returns>
    Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

    /// <summary>
    /// Creates a new instance with the sign flipped.
    /// </summary>
    /// <returns>Flipped sign instance (NaN is returned with its bits unchanged)</returns>
    Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

    /// <summary>
    /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
    /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
    /// and therefore equivalent, if the resulting value is still zero.
    /// </summary>
    /// <param name="lhs">first value</param>
    /// <param name="rhs">second value</param>
    /// <returns>True if both arguments represent zero</returns>
    static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
      return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
    }

    /// <summary>
    /// IEEE-style equality: NaN never compares equal, and +0 equals -0.
    /// </summary>
    /// <param name="rhs">value to compare against</param>
    /// <returns>True if the two values are numerically equal</returns>
    bool operator==(const Float16Impl& rhs) const noexcept {
      if (IsNaN() || rhs.IsNaN()) {
        // IEEE defines that NaN is not equal to anything, including itself.
        return false;
      }
      // Identical bits are equal; +0 and -0 differ only in the sign bit yet must also be
      // equal, which keeps this operator consistent with operator< below.
      return val == rhs.val || AreZero(*this, rhs);
    }

    /// <summary>
    /// Logical negation of operator==; consequently true whenever either operand is NaN.
    /// </summary>
    bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

    /// <summary>
    /// IEEE-style ordering: false if either operand is NaN, and the two zeros are not ordered
    /// with respect to each other.
    /// </summary>
    /// <param name="rhs">value to compare against</param>
    /// <returns>True if this value is strictly less than rhs</returns>
    bool operator<(const Float16Impl& rhs) const noexcept {
      if (IsNaN() || rhs.IsNaN()) {
        // IEEE defines that NaN is unordered with respect to everything, including itself.
        return false;
      }

      const bool left_is_negative = IsNegative();
      if (left_is_negative != rhs.IsNegative()) {
        // When the signs of left and right differ, we know that left is less than right if it is
        // the negative value. The exception to this is if both values are zero, in which case IEEE
        // says they should be equal, even if the signs differ.
        return left_is_negative && !AreZero(*this, rhs);
      }
      // Same sign: the magnitude order matches the unsigned bit order for positives and is
      // reversed for negatives, hence the XOR with the sign.
      return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
    }
  };
  187. // The following Float16_t conversions are based on the code from
  188. // Eigen library.
  189. // The conversion routines are Copyright (c) Fabian Giesen, 2016.
  190. // The original license follows:
  191. //
  192. // Copyright (c) Fabian Giesen, 2016
  193. // All rights reserved.
  194. // Redistribution and use in source and binary forms, with or without
  195. // modification, are permitted.
  196. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  197. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  198. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  199. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  200. // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  201. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  202. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  203. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  204. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  205. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  206. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  namespace detail {

  /// <summary>
  /// Overlay of a float and an unsigned integer, used by the float16 conversion
  /// routines to inspect and edit the bit pattern of a float.
  /// `u` is deliberately the first member: brace initialization such as `{255 << 23}`
  /// initializes the integer view.
  /// </summary>
  union float32_bits {
    unsigned int u;
    float f;
  };

  // The overlay is only meaningful when both views cover exactly the same bytes.
  static_assert(sizeof(unsigned int) == sizeof(float),
                "float32_bits requires unsigned int and float to have the same size.");

  }  // namespace detail
  213. template <class Derived>
  214. inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
  215. detail::float32_bits f{};
  216. f.f = v;
  217. constexpr detail::float32_bits f32infty = {255 << 23};
  218. constexpr detail::float32_bits f16max = {(127 + 16) << 23};
  219. constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
  220. constexpr unsigned int sign_mask = 0x80000000u;
  221. uint16_t val = static_cast<uint16_t>(0x0u);
  222. unsigned int sign = f.u & sign_mask;
  223. f.u ^= sign;
  224. // NOTE all the integer compares in this function can be safely
  225. // compiled into signed compares since all operands are below
  226. // 0x80000000. Important if you want fast straight SSE2 code
  227. // (since there's no unsigned PCMPGTD).
  228. if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
  229. val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
  230. } else { // (De)normalized number or zero
  231. if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
  232. // use a magic value to align our 10 mantissa bits at the bottom of
  233. // the float. as long as FP addition is round-to-nearest-even this
  234. // just works.
  235. f.f += denorm_magic.f;
  236. // and one integer subtract of the bias later, we have our final float!
  237. val = static_cast<uint16_t>(f.u - denorm_magic.u);
  238. } else {
  239. unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
  240. // update exponent, rounding bias part 1
  241. // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
  242. // without arithmetic overflow.
  243. f.u += 0xc8000fffU;
  244. // rounding bias part 2
  245. f.u += mant_odd;
  246. // take the bits!
  247. val = static_cast<uint16_t>(f.u >> 13);
  248. }
  249. }
  250. val |= static_cast<uint16_t>(sign >> 16);
  251. return val;
  252. }
  253. template <class Derived>
  254. inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
  255. constexpr detail::float32_bits magic = {113 << 23};
  256. constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
  257. detail::float32_bits o{};
  258. o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
  259. unsigned int exp = shifted_exp & o.u; // just the exponent
  260. o.u += (127 - 15) << 23; // exponent adjust
  261. // handle exponent special cases
  262. if (exp == shifted_exp) { // Inf/NaN?
  263. o.u += (128 - 16) << 23; // extra exp adjust
  264. } else if (exp == 0) { // Zero/Denormal?
  265. o.u += 1 << 23; // extra exp adjust
  266. o.f -= magic.f; // re-normalize
  267. }
  268. // Attempt to workaround the Internal Compiler Error on ARM64
  269. // for bitwise | operator, including std::bitset
  270. #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
  271. if (IsNegative()) {
  272. return -o.f;
  273. }
  274. #else
  275. // original code:
  276. o.u |= (val & 0x8000U) << 16U; // sign bit
  277. #endif
  278. return o.f;
  279. }
  /// <summary>
  /// Shared implementation between public and internal classes. CRTP pattern.
  /// Stores the raw 16-bit pattern of a bfloat16 value (sign bit 0x8000, exponent
  /// field 0x7F80, remaining 7 bits of mantissa) and classifies it bit-wise.
  /// Derived must provide a static FromBits(uint16_t) factory; Abs() and Negate() call it.
  /// </summary>
  template <class Derived>
  struct BFloat16Impl {
   protected:
    /// <summary>
    /// Converts from float to the uint16_t bfloat16 representation
    /// </summary>
    /// <param name="v">value to convert</param>
    /// <returns>bfloat16 bit pattern</returns>
    static uint16_t ToUint16Impl(float v) noexcept;

    /// <summary>
    /// Converts bfloat16 to float
    /// </summary>
    /// <returns>float representation of bfloat16 value</returns>
    float ToFloatImpl() const noexcept;

    /// <summary>
    /// Bit pattern of the absolute value.
    /// </summary>
    /// <returns>The stored bits with the sign bit cleared</returns>
    uint16_t AbsImpl() const noexcept {
      // kSignMask - 1 selects every bit below the sign bit.
      return static_cast<uint16_t>(val & (kSignMask - 1));
    }

    /// <summary>
    /// Bit pattern of the negated value.
    /// </summary>
    /// <returns>The stored bits with the sign toggled; NaN bits are returned untouched</returns>
    uint16_t NegateImpl() const noexcept {
      if (IsNaN()) {
        return val;
      }
      return static_cast<uint16_t>(val ^ kSignMask);
    }

   public:
    // uint16_t special values
    static constexpr uint16_t kSignMask = 0x8000U;
    static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
    static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
    static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
    static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
    static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
    // NOTE(review): identical to kPositiveInfinityBits, so IsNaN() is false for this
    // pattern - confirm whether a real signaling-NaN encoding was intended.
    static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
    // NOTE(review): 0x0080 is the smallest positive normal value, not the machine
    // epsilon 2^-7 (0x3C00) - confirm the intended meaning with callers.
    static constexpr uint16_t kEpsilonBits = 0x0080U;
    static constexpr uint16_t kMinValueBits = 0xFF7FU;  // lowest (most negative) finite value
    static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // largest finite value
    static constexpr uint16_t kRoundToNearest = 0x7FFFU;
    static constexpr uint16_t kOneBits = 0x3F80U;       // 1.0
    static constexpr uint16_t kMinusOneBits = 0xBF80U;  // -1.0

    // Raw bit pattern; zero-initialized, i.e. +0.0.
    uint16_t val{0};

    BFloat16Impl() = default;

    /// <summary>
    /// Checks if the value is negative
    /// </summary>
    /// <returns>true if the sign bit is set</returns>
    bool IsNegative() const noexcept {
      return (val & kSignMask) != 0;
    }

    /// <summary>
    /// Tests if the value is NaN
    /// </summary>
    /// <returns>true if NaN</returns>
    bool IsNaN() const noexcept {
      // All-ones exponent plus any mantissa bit sorts strictly above infinity.
      return kPositiveInfinityBits < AbsImpl();
    }

    /// <summary>
    /// Tests if the value is finite
    /// </summary>
    /// <returns>true if finite</returns>
    bool IsFinite() const noexcept {
      // Only infinities and NaNs have an all-ones exponent field.
      return (val & kBiasedExponentMask) != kBiasedExponentMask;
    }

    /// <summary>
    /// Tests if the value represents positive infinity.
    /// </summary>
    /// <returns>true if positive infinity</returns>
    bool IsPositiveInfinity() const noexcept {
      return !IsNegative() && IsInfinity();
    }

    /// <summary>
    /// Tests if the value represents negative infinity
    /// </summary>
    /// <returns>true if negative infinity</returns>
    bool IsNegativeInfinity() const noexcept {
      return IsNegative() && IsInfinity();
    }

    /// <summary>
    /// Tests if the value is either positive or negative infinity.
    /// </summary>
    /// <returns>True if absolute value is infinity</returns>
    bool IsInfinity() const noexcept {
      // Forcing the sign bit on maps both infinities onto the negative one.
      return (val | kSignMask) == kNegativeInfinityBits;
    }

    /// <summary>
    /// Tests if the value is NaN or zero. Useful for comparisons.
    /// </summary>
    /// <returns>True if NaN or zero.</returns>
    bool IsNaNOrZero() const noexcept {
      const uint16_t magnitude = AbsImpl();
      return !(magnitude != 0 && magnitude <= kPositiveInfinityBits);
    }

    /// <summary>
    /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
    /// </summary>
    /// <returns>True if so</returns>
    bool IsNormal() const noexcept {
      // Normal values are exactly those whose exponent field is neither all zeros
      // (zero/subnormal) nor all ones (infinity/NaN).
      const uint16_t exponent = static_cast<uint16_t>(val & kBiasedExponentMask);
      return exponent != 0 && exponent != kBiasedExponentMask;
    }

    /// <summary>
    /// Tests if the value is subnormal (denormal).
    /// </summary>
    /// <returns>True if so</returns>
    bool IsSubnormal() const noexcept {
      // Zero exponent field with at least one mantissa bit set.
      const uint16_t magnitude = AbsImpl();
      return magnitude != 0 && (magnitude & kBiasedExponentMask) == 0;
    }

    /// <summary>
    /// Creates an instance that represents absolute value.
    /// </summary>
    /// <returns>Absolute value</returns>
    Derived Abs() const noexcept {
      return Derived::FromBits(AbsImpl());
    }

    /// <summary>
    /// Creates a new instance with the sign flipped.
    /// </summary>
    /// <returns>Flipped sign instance (NaN keeps its bits)</returns>
    Derived Negate() const noexcept {
      return Derived::FromBits(NegateImpl());
    }

    /// <summary>
    /// IEEE defines that positive and negative zero are equal; this reports whether both
    /// arguments are a zero of either sign, in which case they are equivalent.
    /// </summary>
    /// <param name="lhs">first value</param>
    /// <param name="rhs">second value</param>
    /// <returns>True if both arguments represent zero</returns>
    static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
      return lhs.AbsImpl() == 0 && rhs.AbsImpl() == 0;
    }
  };
  421. template <class Derived>
  422. inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
  423. uint16_t result;
  424. if (std::isnan(v)) {
  425. result = kPositiveQNaNBits;
  426. } else {
  427. auto get_msb_half = [](float fl) {
  428. uint16_t result;
  429. #ifdef __cpp_if_constexpr
  430. if constexpr (detail::endian::native == detail::endian::little) {
  431. #else
  432. if (detail::endian::native == detail::endian::little) {
  433. #endif
  434. std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
  435. } else {
  436. std::memcpy(&result, &fl, sizeof(uint16_t));
  437. }
  438. return result;
  439. };
  440. uint16_t upper_bits = get_msb_half(v);
  441. union {
  442. uint32_t U32;
  443. float F32;
  444. };
  445. F32 = v;
  446. U32 += (upper_bits & 1) + kRoundToNearest;
  447. result = get_msb_half(F32);
  448. }
  449. return result;
  450. }
  451. template <class Derived>
  452. inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
  453. if (IsNaN()) {
  454. return std::numeric_limits<float>::quiet_NaN();
  455. }
  456. float result;
  457. char* const first = reinterpret_cast<char*>(&result);
  458. char* const second = first + sizeof(uint16_t);
  459. #ifdef __cpp_if_constexpr
  460. if constexpr (detail::endian::native == detail::endian::little) {
  461. #else
  462. if (detail::endian::native == detail::endian::little) {
  463. #endif
  464. std::memset(first, 0, sizeof(uint16_t));
  465. std::memcpy(second, &val, sizeof(uint16_t));
  466. } else {
  467. std::memcpy(first, &val, sizeof(uint16_t));
  468. std::memset(second, 0, sizeof(uint16_t));
  469. }
  470. return result;
  471. }
  472. } // namespace onnxruntime_float16