123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- #pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstdint>
#include <cstring>

#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif

#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp> // for SYCL 1.2.1
#else
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
#include <ext/oneapi/bfloat16.hpp>
#endif
- namespace c10 {
- namespace detail {
- inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
- float res = 0;
- uint32_t tmp = src;
- tmp <<= 16;
- #if defined(USE_ROCM)
- float* tempRes;
- // We should be using memcpy in order to respect the strict aliasing rule
- // but it fails in the HIP environment.
- tempRes = reinterpret_cast<float*>(&tmp);
- res = *tempRes;
- #else
- std::memcpy(&res, &tmp, sizeof(tmp));
- #endif
- return res;
- }
- inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
- uint32_t res = 0;
- #if defined(USE_ROCM)
- // We should be using memcpy in order to respect the strict aliasing rule
- // but it fails in the HIP environment.
- uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
- res = *tempRes;
- #else
- std::memcpy(&res, &src, sizeof(res));
- #endif
- return res >> 16;
- }
- inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
- #if defined(USE_ROCM)
- if (src != src) {
- #elif defined(_MSC_VER)
- if (isnan(src)) {
- #else
- if (std::isnan(src)) {
- #endif
- return UINT16_C(0x7FC0);
- } else {
- union {
- uint32_t U32;
- float F32;
- };
- F32 = src;
- uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
- return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
- }
- }
- } // namespace detail
- struct alignas(2) BFloat16 {
- uint16_t x;
- // HIP wants __host__ __device__ tag, CUDA does not
- #if defined(USE_ROCM)
- C10_HOST_DEVICE BFloat16() = default;
- #else
- BFloat16() = default;
- #endif
- struct from_bits_t {};
- static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
- return from_bits_t();
- }
- constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
- : x(bits){};
- inline C10_HOST_DEVICE BFloat16(float value);
- inline C10_HOST_DEVICE operator float() const;
- #if defined(__CUDACC__) && !defined(USE_ROCM)
- inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
- explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
- #endif
- #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
- inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
- explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
- #endif
- };
- } // namespace c10
- #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
|