safe_numerics.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #pragma once
  2. #include <c10/macros/Macros.h>
  3. #include <c10/util/ArrayRef.h>
  4. #include <iterator>
  5. #include <numeric>
  6. #include <type_traits>
  7. // GCC has __builtin_mul_overflow from before it supported __has_builtin
  8. #ifdef _MSC_VER
  9. #define C10_HAS_BUILTIN_OVERFLOW() (0)
  10. #include <c10/util/llvmMathExtras.h>
  11. #include <intrin.h>
  12. #else
  13. #define C10_HAS_BUILTIN_OVERFLOW() (1)
  14. #endif
  15. namespace c10 {
  16. C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
  17. #if C10_HAS_BUILTIN_OVERFLOW()
  18. return __builtin_add_overflow(a, b, out);
  19. #else
  20. unsigned long long tmp;
  21. #if defined(_M_IX86) || defined(_M_X64)
  22. auto carry = _addcarry_u64(0, a, b, &tmp);
  23. #else
  24. tmp = a + b;
  25. unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
  26. auto carry = vector >> 63;
  27. #endif
  28. *out = tmp;
  29. return carry;
  30. #endif
  31. }
  32. C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
  33. #if C10_HAS_BUILTIN_OVERFLOW()
  34. return __builtin_mul_overflow(a, b, out);
  35. #else
  36. *out = a * b;
  37. // This test isnt exact, but avoids doing integer division
  38. return (
  39. (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
  40. #endif
  41. }
  42. template <typename It>
  43. bool safe_multiplies_u64(It first, It last, uint64_t* out) {
  44. #if C10_HAS_BUILTIN_OVERFLOW()
  45. uint64_t prod = 1;
  46. bool overflow = false;
  47. for (; first != last; ++first) {
  48. overflow |= c10::mul_overflows(prod, *first, &prod);
  49. }
  50. *out = prod;
  51. return overflow;
  52. #else
  53. uint64_t prod = 1;
  54. uint64_t prod_log2 = 0;
  55. bool is_zero = false;
  56. for (; first != last; ++first) {
  57. auto x = static_cast<uint64_t>(*first);
  58. prod *= x;
  59. // log2(0) isn't valid, so need to track it specially
  60. is_zero |= (x == 0);
  61. prod_log2 += c10::llvm::Log2_64_Ceil(x);
  62. }
  63. *out = prod;
  64. // This test isnt exact, but avoids doing integer division
  65. return !is_zero && (prod_log2 >= 64);
  66. #endif
  67. }
  // Convenience overload: multiplies all elements of `c` by forwarding its
  // begin()/end() range to the iterator-based safe_multiplies_u64. The
  // product is stored through `out` and the iterator overload's overflow
  // flag is returned unchanged.
  template <typename Container>
  bool safe_multiplies_u64(const Container& c, uint64_t* out) {
    auto range_begin = c.begin();
    auto range_end = c.end();
    return safe_multiplies_u64(range_begin, range_end, out);
  }
  72. } // namespace c10