NumericLimits.cuh 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #pragma once
  2. #include <cuda.h>
  3. #include <limits.h>
  4. #include <math.h>
  5. #include <float.h>
  // NumericLimits.cuh is a holder for numeric limits definitions of commonly used
  // types. This header is very specific to ROCm HIP and may be removed in the future.
  // This header is derived from the legacy THCNumerics.cuh.
  // The lower_bound and upper_bound constants are the same as lowest and max for
  // integral types, but are -inf and +inf for floating-point types. They are
  // useful in implementing min, max, etc.
  12. namespace at {
  // Primary template. It is deliberately left without members: only the explicit
  // specializations that follow provide lowest()/max()/lower_bound()/upper_bound().
  template <typename T>
  struct numeric_limits {};
  16. // WARNING: the following at::numeric_limits definitions are there only to support
  17. // HIP compilation for the moment. Use std::numeric_limits if you are not
  18. // compiling for ROCm.
  19. // from @colesbury: "The functions on numeric_limits aren't marked with
  20. // __device__ which is why they don't work with ROCm. CUDA allows them
  21. // because they're constexpr."
  namespace {
  // Positive infinity as a double. The floating-point specializations in this
  // header go through this constant because ROCm doesn't like INFINITY either.
  constexpr double inf{INFINITY};
  } // unnamed namespace
  // Limits for bool: false is the smallest value and true the largest.
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<bool> {
    static inline __host__ __device__ bool lowest() {
      return false;
    }
    static inline __host__ __device__ bool max() {
      return true;
    }
    static inline __host__ __device__ bool lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ bool upper_bound() {
      return max();
    }
  };
  // Limits for uint8_t: range is [0, UINT8_MAX].
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<uint8_t> {
    static inline __host__ __device__ uint8_t lowest() {
      return 0;
    }
    static inline __host__ __device__ uint8_t max() {
      return UINT8_MAX;
    }
    static inline __host__ __device__ uint8_t lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ uint8_t upper_bound() {
      return max();
    }
  };
  // Limits for int8_t: range is [INT8_MIN, INT8_MAX].
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<int8_t> {
    static inline __host__ __device__ int8_t lowest() {
      return INT8_MIN;
    }
    static inline __host__ __device__ int8_t max() {
      return INT8_MAX;
    }
    static inline __host__ __device__ int8_t lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ int8_t upper_bound() {
      return max();
    }
  };
  // Limits for int16_t: range is [INT16_MIN, INT16_MAX].
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<int16_t> {
    static inline __host__ __device__ int16_t lowest() {
      return INT16_MIN;
    }
    static inline __host__ __device__ int16_t max() {
      return INT16_MAX;
    }
    static inline __host__ __device__ int16_t lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ int16_t upper_bound() {
      return max();
    }
  };
  // Limits for int32_t: range is [INT32_MIN, INT32_MAX].
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<int32_t> {
    static inline __host__ __device__ int32_t lowest() {
      return INT32_MIN;
    }
    static inline __host__ __device__ int32_t max() {
      return INT32_MAX;
    }
    static inline __host__ __device__ int32_t lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ int32_t upper_bound() {
      return max();
    }
  };
  // Limits for int64_t. When _MSC_VER is defined the _I64_MIN/_I64_MAX macros
  // are used; otherwise INT64_MIN/INT64_MAX are used.
  // lower_bound()/upper_bound() yield the same values as lowest()/max().
  template <>
  struct numeric_limits<int64_t> {
    static inline __host__ __device__ int64_t lowest() {
  #ifdef _MSC_VER
      return _I64_MIN;
  #else
      return INT64_MIN;
  #endif
    }
    static inline __host__ __device__ int64_t max() {
  #ifdef _MSC_VER
      return _I64_MAX;
  #else
      return INT64_MAX;
  #endif
    }
    static inline __host__ __device__ int64_t lower_bound() {
      return lowest();
    }
    static inline __host__ __device__ int64_t upper_bound() {
      return max();
    }
  };
  // Limits for at::Half, each built directly from a raw 16-bit pattern through
  // the from_bits constructor tag:
  //   lowest()      -> 0xFBFF    max()         -> 0x7BFF
  //   lower_bound() -> 0xFC00    upper_bound() -> 0x7C00
  // NOTE(review): presumably these are the IEEE half-precision encodings of
  // -max, +max, -inf and +inf respectively -- confirm against at::Half.
  template <>
  struct numeric_limits<at::Half> {
    static inline __host__ __device__ at::Half lowest() {
      const auto raw_bits_tag = at::Half::from_bits();
      return at::Half(0xFBFF, raw_bits_tag);
    }
    static inline __host__ __device__ at::Half max() {
      const auto raw_bits_tag = at::Half::from_bits();
      return at::Half(0x7BFF, raw_bits_tag);
    }
    static inline __host__ __device__ at::Half lower_bound() {
      const auto raw_bits_tag = at::Half::from_bits();
      return at::Half(0xFC00, raw_bits_tag);
    }
    static inline __host__ __device__ at::Half upper_bound() {
      const auto raw_bits_tag = at::Half::from_bits();
      return at::Half(0x7C00, raw_bits_tag);
    }
  };
  // Limits for at::BFloat16, each built directly from a raw 16-bit pattern
  // through the from_bits constructor tag:
  //   lowest()      -> 0xFF7F    max()         -> 0x7F7F
  //   lower_bound() -> 0xFF80    upper_bound() -> 0x7F80
  // NOTE(review): presumably these are the bfloat16 encodings of
  // -max, +max, -inf and +inf respectively -- confirm against at::BFloat16.
  template <>
  struct numeric_limits<at::BFloat16> {
    static inline __host__ __device__ at::BFloat16 lowest() {
      const auto raw_bits_tag = at::BFloat16::from_bits();
      return at::BFloat16(0xFF7F, raw_bits_tag);
    }
    static inline __host__ __device__ at::BFloat16 max() {
      const auto raw_bits_tag = at::BFloat16::from_bits();
      return at::BFloat16(0x7F7F, raw_bits_tag);
    }
    static inline __host__ __device__ at::BFloat16 lower_bound() {
      const auto raw_bits_tag = at::BFloat16::from_bits();
      return at::BFloat16(0xFF80, raw_bits_tag);
    }
    static inline __host__ __device__ at::BFloat16 upper_bound() {
      const auto raw_bits_tag = at::BFloat16::from_bits();
      return at::BFloat16(0x7F80, raw_bits_tag);
    }
  };
  // Limits for float. lowest()/max() are the finite extremes -FLT_MAX/FLT_MAX,
  // while lower_bound()/upper_bound() are -inf/+inf obtained by narrowing the
  // file-level double constant `inf` to float.
  template <>
  struct numeric_limits<float> {
    static inline __host__ __device__ float lowest() {
      return -max();
    }
    static inline __host__ __device__ float max() {
      return FLT_MAX;
    }
    static inline __host__ __device__ float lower_bound() {
      return -upper_bound();
    }
    static inline __host__ __device__ float upper_bound() {
      return static_cast<float>(inf);
    }
  };
  // Limits for double. lowest()/max() are the finite extremes -DBL_MAX/DBL_MAX,
  // while lower_bound()/upper_bound() are -inf/+inf taken from the file-level
  // constant `inf`.
  template <>
  struct numeric_limits<double> {
    static inline __host__ __device__ double lowest() {
      return -max();
    }
    static inline __host__ __device__ double max() {
      return DBL_MAX;
    }
    static inline __host__ __device__ double lower_bound() {
      return -upper_bound();
    }
    static inline __host__ __device__ double upper_bound() {
      return inf;
    }
  };
  103. } // namespace at