123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540 |
- // Copyright (c) Microsoft Corporation. All rights reserved.
- // Licensed under the MIT License.
- #pragma once
- #include <stdint.h>
- #include <cmath>
- #include <cstring>
- #include <limits>
- namespace onnxruntime_float16 {
- namespace detail {
- enum class endian {
- #if defined(_WIN32)
- little = 0,
- big = 1,
- native = little,
- #elif defined(__GNUC__) || defined(__clang__)
- little = __ORDER_LITTLE_ENDIAN__,
- big = __ORDER_BIG_ENDIAN__,
- native = __BYTE_ORDER__,
- #else
- #error onnxruntime_float16::detail::endian is not implemented in this environment.
- #endif
- };
- static_assert(
- endian::native == endian::little || endian::native == endian::big,
- "Only little-endian or big-endian native byte orders are supported.");
- } // namespace detail
- /// <summary>
- /// Shared implementation between public and internal classes. CRTP pattern.
- /// </summary>
- template <class Derived>
- struct Float16Impl {
- protected:
- /// <summary>
- /// Converts from float to uint16_t float16 representation
- /// </summary>
- /// <param name="v"></param>
- /// <returns></returns>
- constexpr static uint16_t ToUint16Impl(float v) noexcept;
- /// <summary>
- /// Converts float16 to float
- /// </summary>
- /// <returns>float representation of float16 value</returns>
- float ToFloatImpl() const noexcept;
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- uint16_t AbsImpl() const noexcept {
- return static_cast<uint16_t>(val & ~kSignMask);
- }
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- uint16_t NegateImpl() const noexcept {
- return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
- }
- public:
- // uint16_t special values
- static constexpr uint16_t kSignMask = 0x8000U;
- static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
- static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
- static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
- static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
- static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
- static constexpr uint16_t kEpsilonBits = 0x4170U;
- static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number
- static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
- static constexpr uint16_t kOneBits = 0x3C00U;
- static constexpr uint16_t kMinusOneBits = 0xBC00U;
- uint16_t val{0};
- Float16Impl() = default;
- /// <summary>
- /// Checks if the value is negative
- /// </summary>
- /// <returns>true if negative</returns>
- bool IsNegative() const noexcept {
- return static_cast<int16_t>(val) < 0;
- }
- /// <summary>
- /// Tests if the value is NaN
- /// </summary>
- /// <returns>true if NaN</returns>
- bool IsNaN() const noexcept {
- return AbsImpl() > kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value is finite
- /// </summary>
- /// <returns>true if finite</returns>
- bool IsFinite() const noexcept {
- return AbsImpl() < kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value represents positive infinity.
- /// </summary>
- /// <returns>true if positive infinity</returns>
- bool IsPositiveInfinity() const noexcept {
- return val == kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value represents negative infinity
- /// </summary>
- /// <returns>true if negative infinity</returns>
- bool IsNegativeInfinity() const noexcept {
- return val == kNegativeInfinityBits;
- }
- /// <summary>
- /// Tests if the value is either positive or negative infinity.
- /// </summary>
- /// <returns>True if absolute value is infinity</returns>
- bool IsInfinity() const noexcept {
- return AbsImpl() == kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value is NaN or zero. Useful for comparisons.
- /// </summary>
- /// <returns>True if NaN or zero.</returns>
- bool IsNaNOrZero() const noexcept {
- auto abs = AbsImpl();
- return (abs == 0 || abs > kPositiveInfinityBits);
- }
- /// <summary>
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
- /// </summary>
- /// <returns>True if so</returns>
- bool IsNormal() const noexcept {
- auto abs = AbsImpl();
- return (abs < kPositiveInfinityBits) // is finite
- && (abs != 0) // is not zero
- && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
- }
- /// <summary>
- /// Tests if the value is subnormal (denormal).
- /// </summary>
- /// <returns>True if so</returns>
- bool IsSubnormal() const noexcept {
- auto abs = AbsImpl();
- return (abs < kPositiveInfinityBits) // is finite
- && (abs != 0) // is not zero
- && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
- }
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
- /// <summary>
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
- /// and therefore equivalent, if the resulting value is still zero.
- /// </summary>
- /// <param name="lhs">first value</param>
- /// <param name="rhs">second value</param>
- /// <returns>True if both arguments represent zero</returns>
- static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
- return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
- }
- bool operator==(const Float16Impl& rhs) const noexcept {
- if (IsNaN() || rhs.IsNaN()) {
- // IEEE defines that NaN is not equal to anything, including itself.
- return false;
- }
- return val == rhs.val;
- }
- bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
- bool operator<(const Float16Impl& rhs) const noexcept {
- if (IsNaN() || rhs.IsNaN()) {
- // IEEE defines that NaN is unordered with respect to everything, including itself.
- return false;
- }
- const bool left_is_negative = IsNegative();
- if (left_is_negative != rhs.IsNegative()) {
- // When the signs of left and right differ, we know that left is less than right if it is
- // the negative value. The exception to this is if both values are zero, in which case IEEE
- // says they should be equal, even if the signs differ.
- return left_is_negative && !AreZero(*this, rhs);
- }
- return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
- }
- };
- // The following Float16_t conversions are based on the code from
- // Eigen library.
- // The conversion routines are Copyright (c) Fabian Giesen, 2016.
- // The original license follows:
- //
- // Copyright (c) Fabian Giesen, 2016
- // All rights reserved.
- // Redistribution and use in source and binary forms, with or without
- // modification, are permitted.
- // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- namespace detail {
- union float32_bits {
- unsigned int u;
- float f;
- };
- } // namespace detail
- template <class Derived>
- inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
- detail::float32_bits f{};
- f.f = v;
- constexpr detail::float32_bits f32infty = {255 << 23};
- constexpr detail::float32_bits f16max = {(127 + 16) << 23};
- constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
- constexpr unsigned int sign_mask = 0x80000000u;
- uint16_t val = static_cast<uint16_t>(0x0u);
- unsigned int sign = f.u & sign_mask;
- f.u ^= sign;
- // NOTE all the integer compares in this function can be safely
- // compiled into signed compares since all operands are below
- // 0x80000000. Important if you want fast straight SSE2 code
- // (since there's no unsigned PCMPGTD).
- if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
- val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
- } else { // (De)normalized number or zero
- if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
- // use a magic value to align our 10 mantissa bits at the bottom of
- // the float. as long as FP addition is round-to-nearest-even this
- // just works.
- f.f += denorm_magic.f;
- // and one integer subtract of the bias later, we have our final float!
- val = static_cast<uint16_t>(f.u - denorm_magic.u);
- } else {
- unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
- // update exponent, rounding bias part 1
- // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
- // without arithmetic overflow.
- f.u += 0xc8000fffU;
- // rounding bias part 2
- f.u += mant_odd;
- // take the bits!
- val = static_cast<uint16_t>(f.u >> 13);
- }
- }
- val |= static_cast<uint16_t>(sign >> 16);
- return val;
- }
- template <class Derived>
- inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
- constexpr detail::float32_bits magic = {113 << 23};
- constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
- detail::float32_bits o{};
- o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
- unsigned int exp = shifted_exp & o.u; // just the exponent
- o.u += (127 - 15) << 23; // exponent adjust
- // handle exponent special cases
- if (exp == shifted_exp) { // Inf/NaN?
- o.u += (128 - 16) << 23; // extra exp adjust
- } else if (exp == 0) { // Zero/Denormal?
- o.u += 1 << 23; // extra exp adjust
- o.f -= magic.f; // re-normalize
- }
- // Attempt to workaround the Internal Compiler Error on ARM64
- // for bitwise | operator, including std::bitset
- #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
- if (IsNegative()) {
- return -o.f;
- }
- #else
- // original code:
- o.u |= (val & 0x8000U) << 16U; // sign bit
- #endif
- return o.f;
- }
- /// Shared implementation between public and internal classes. CRTP pattern.
- template <class Derived>
- struct BFloat16Impl {
- protected:
- /// <summary>
- /// Converts from float to uint16_t float16 representation
- /// </summary>
- /// <param name="v"></param>
- /// <returns></returns>
- static uint16_t ToUint16Impl(float v) noexcept;
- /// <summary>
- /// Converts bfloat16 to float
- /// </summary>
- /// <returns>float representation of bfloat16 value</returns>
- float ToFloatImpl() const noexcept;
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- uint16_t AbsImpl() const noexcept {
- return static_cast<uint16_t>(val & ~kSignMask);
- }
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- uint16_t NegateImpl() const noexcept {
- return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
- }
- public:
- // uint16_t special values
- static constexpr uint16_t kSignMask = 0x8000U;
- static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
- static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
- static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
- static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
- static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
- static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
- static constexpr uint16_t kEpsilonBits = 0x0080U;
- static constexpr uint16_t kMinValueBits = 0xFF7FU;
- static constexpr uint16_t kMaxValueBits = 0x7F7FU;
- static constexpr uint16_t kRoundToNearest = 0x7FFFU;
- static constexpr uint16_t kOneBits = 0x3F80U;
- static constexpr uint16_t kMinusOneBits = 0xBF80U;
- uint16_t val{0};
- BFloat16Impl() = default;
- /// <summary>
- /// Checks if the value is negative
- /// </summary>
- /// <returns>true if negative</returns>
- bool IsNegative() const noexcept {
- return static_cast<int16_t>(val) < 0;
- }
- /// <summary>
- /// Tests if the value is NaN
- /// </summary>
- /// <returns>true if NaN</returns>
- bool IsNaN() const noexcept {
- return AbsImpl() > kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value is finite
- /// </summary>
- /// <returns>true if finite</returns>
- bool IsFinite() const noexcept {
- return AbsImpl() < kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value represents positive infinity.
- /// </summary>
- /// <returns>true if positive infinity</returns>
- bool IsPositiveInfinity() const noexcept {
- return val == kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value represents negative infinity
- /// </summary>
- /// <returns>true if negative infinity</returns>
- bool IsNegativeInfinity() const noexcept {
- return val == kNegativeInfinityBits;
- }
- /// <summary>
- /// Tests if the value is either positive or negative infinity.
- /// </summary>
- /// <returns>True if absolute value is infinity</returns>
- bool IsInfinity() const noexcept {
- return AbsImpl() == kPositiveInfinityBits;
- }
- /// <summary>
- /// Tests if the value is NaN or zero. Useful for comparisons.
- /// </summary>
- /// <returns>True if NaN or zero.</returns>
- bool IsNaNOrZero() const noexcept {
- auto abs = AbsImpl();
- return (abs == 0 || abs > kPositiveInfinityBits);
- }
- /// <summary>
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
- /// </summary>
- /// <returns>True if so</returns>
- bool IsNormal() const noexcept {
- auto abs = AbsImpl();
- return (abs < kPositiveInfinityBits) // is finite
- && (abs != 0) // is not zero
- && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
- }
- /// <summary>
- /// Tests if the value is subnormal (denormal).
- /// </summary>
- /// <returns>True if so</returns>
- bool IsSubnormal() const noexcept {
- auto abs = AbsImpl();
- return (abs < kPositiveInfinityBits) // is finite
- && (abs != 0) // is not zero
- && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
- }
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
- /// <summary>
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
- /// and therefore equivalent, if the resulting value is still zero.
- /// </summary>
- /// <param name="lhs">first value</param>
- /// <param name="rhs">second value</param>
- /// <returns>True if both arguments represent zero</returns>
- static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
- // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
- // for two values by or'ing the private bits together and stripping the sign. They are both zero,
- // and therefore equivalent, if the resulting value is still zero.
- return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
- }
- };
- template <class Derived>
- inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
- uint16_t result;
- if (std::isnan(v)) {
- result = kPositiveQNaNBits;
- } else {
- auto get_msb_half = [](float fl) {
- uint16_t result;
- #ifdef __cpp_if_constexpr
- if constexpr (detail::endian::native == detail::endian::little) {
- #else
- if (detail::endian::native == detail::endian::little) {
- #endif
- std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
- } else {
- std::memcpy(&result, &fl, sizeof(uint16_t));
- }
- return result;
- };
- uint16_t upper_bits = get_msb_half(v);
- union {
- uint32_t U32;
- float F32;
- };
- F32 = v;
- U32 += (upper_bits & 1) + kRoundToNearest;
- result = get_msb_half(F32);
- }
- return result;
- }
- template <class Derived>
- inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
- if (IsNaN()) {
- return std::numeric_limits<float>::quiet_NaN();
- }
- float result;
- char* const first = reinterpret_cast<char*>(&result);
- char* const second = first + sizeof(uint16_t);
- #ifdef __cpp_if_constexpr
- if constexpr (detail::endian::native == detail::endian::little) {
- #else
- if (detail::endian::native == detail::endian::little) {
- #endif
- std::memset(first, 0, sizeof(uint16_t));
- std::memcpy(second, &val, sizeof(uint16_t));
- } else {
- std::memcpy(first, &val, sizeof(uint16_t));
- std::memset(second, 0, sizeof(uint16_t));
- }
- return result;
- }
- } // namespace onnxruntime_float16
|