BFloat16.h 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #pragma once
  2. // Defines the bfloat16 type (brain floating-point). This representation uses
  3. // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
  4. #include <c10/macros/Macros.h>
  5. #include <cmath>
  6. #include <cstring>
  7. #if defined(__CUDACC__) && !defined(USE_ROCM)
  8. #include <cuda_bf16.h>
  9. #endif
  10. #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  11. #if defined(CL_SYCL_LANGUAGE_VERSION)
  12. #include <CL/sycl.hpp> // for SYCL 1.2.1
  13. #else
  14. #include <sycl/sycl.hpp> // for SYCL 2020
  15. #endif
  16. #include <ext/oneapi/bfloat16.hpp>
  17. #endif
  18. namespace c10 {
  19. namespace detail {
  // Expands a 16-bit pattern into a float: the 16 bits become the upper half of
  // a 32-bit word (lower half zero) and that word is reinterpreted as a float.
  inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
    // Widen first so the shift happens in 32 bits and no input bit is lost.
    uint32_t widened = static_cast<uint32_t>(src) << 16;
    float value = 0.0f;
  #if defined(USE_ROCM)
    // memcpy would be the choice that respects the strict aliasing rule, but it
    // fails in the HIP environment, so the word is read through a cast pointer.
    value = *reinterpret_cast<float*>(&widened);
  #else
    std::memcpy(&value, &widened, sizeof(widened));
  #endif
    return value;
  }
  // Returns the upper 16 bits of the 32-bit representation of `src`. The lower
  // 16 bits are simply discarded; no rounding is applied here.
  inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
    uint32_t word = 0;
  #if defined(USE_ROCM)
    // memcpy would be the choice that respects the strict aliasing rule, but it
    // fails in the HIP environment, so the float is read through a cast pointer.
    word = *reinterpret_cast<uint32_t*>(&src);
  #else
    std::memcpy(&word, &src, sizeof(word));
  #endif
    // After the shift only 16 significant bits remain, so the narrowing is exact.
    return static_cast<uint16_t>(word >> 16);
  }
  // Converts `src` to a 16-bit pattern made of the upper half of its 32-bit
  // representation, rounded to nearest with ties going to the even result.
  //
  // Returns 0x7FC0 for every NaN input. NaN is handled up front because the
  // integer addition below does not preserve it: for example the NaN pattern
  // 0x7F800001 would come out as 0x7F80, which is the infinity pattern.
  inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
  #if defined(USE_ROCM)
    const bool src_is_nan = (src != src);
  #elif defined(_MSC_VER)
    const bool src_is_nan = isnan(src);
  #else
    const bool src_is_nan = std::isnan(src);
  #endif
    if (src_is_nan) {
      return UINT16_C(0x7FC0);
    }

  #if defined(USE_ROCM)
    // memcpy is reported to fail in the HIP environment, so the ROCm build keeps
    // reading the bits through an anonymous union as before.
    union {
      uint32_t U32;
      float F32;
    };
    F32 = src;
    const uint32_t bits = U32;
  #else
    // Copy the object representation instead of reading an inactive union
    // member, which ISO C++ does not permit.
    uint32_t bits = 0;
    std::memcpy(&bits, &src, sizeof(bits));
  #endif

    // Adding 0x7FFF plus the lowest kept bit carries into the upper half when
    // the discarded half is above 0x8000, or exactly 0x8000 with an odd kept
    // half. That is round-to-nearest with ties to even.
    const uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
    return static_cast<uint16_t>((bits + rounding_bias) >> 16);
  }
  66. } // namespace detail
  // Two-byte, two-byte-aligned value type holding a bfloat16 bit pattern.
  // The float / vendor-type conversions are only declared here; their
  // definitions are not part of this struct body (presumably they live in the
  // -inl.h header included at the end of the file -- confirm there).
  struct alignas(2) BFloat16 {
    // Tag type used to select the constructor that takes raw bits.
    struct from_bits_t {};

    // Returns a tag value to pass as the second argument of the raw-bits
    // constructor.
    static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
      return from_bits_t{};
    }

    // The stored 16-bit pattern.
    uint16_t x;

  // HIP wants the __host__ __device__ tag on the defaulted constructor, CUDA
  // does not.
  #if defined(USE_ROCM)
    C10_HOST_DEVICE BFloat16() = default;
  #else
    BFloat16() = default;
  #endif

    // Stores `bits` unchanged; the tag argument only disambiguates this overload
    // from the float constructor.
    constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
        : x(bits) {}

    // Implicit conversions from and to float (declared only).
    inline C10_HOST_DEVICE BFloat16(float value);
    inline C10_HOST_DEVICE operator float() const;

  #if defined(__CUDACC__) && !defined(USE_ROCM)
    // Interop with the CUDA bfloat16 type; converting back to it is explicit.
    inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
    explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
  #endif

  #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
    // Interop with the SYCL oneAPI bfloat16 type; converting back to it is
    // explicit.
    inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
    explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
  #endif
  };
  92. } // namespace c10
  93. #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep