numext.cpp 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  // This file is part of Eigen, a lightweight C++ template library
  // for linear algebra.
  //
  // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
  //
  // This Source Code Form is subject to the terms of the Mozilla
  // Public License v. 2.0. If a copy of the MPL was not distributed
  // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. template<typename T, typename U>
  11. bool check_if_equal_or_nans(const T& actual, const U& expected) {
  12. return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
  13. }
  14. template<typename T, typename U>
  15. bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
  16. return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
  17. && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
  18. }
  // Wrapper around check_if_equal_or_nans that prints both values to std::cerr
  // when they do not match, so a failing VERIFY shows what was compared.
  template<typename T, typename U>
  bool test_is_equal_or_nans(const T& actual, const U& expected)
  {
    const bool matches = check_if_equal_or_nans(actual, expected);
    if (!matches) {
      std::cerr << "\n actual = " << actual;
      std::cerr << "\n expected = " << expected;
      std::cerr << "\n\n";
    }
    return matches;
  }
  // Assertion: passes when a == b or both are NaN (component-wise for complex values).
  #define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
  32. template<typename T>
  33. void check_abs() {
  34. typedef typename NumTraits<T>::Real Real;
  35. Real zero(0);
  36. if(NumTraits<T>::IsSigned)
  37. VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
  38. VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
  39. VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
  40. for(int k=0; k<100; ++k)
  41. {
  42. T x = internal::random<T>();
  43. if(!internal::is_same<T,bool>::value)
  44. x = x/Real(2);
  45. if(NumTraits<T>::IsSigned)
  46. {
  47. VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
  48. VERIFY( numext::abs(-x) >= zero );
  49. }
  50. VERIFY( numext::abs(x) >= zero );
  51. VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
  52. }
  53. }
  54. template<typename T>
  55. void check_arg() {
  56. typedef typename NumTraits<T>::Real Real;
  57. VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
  58. VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
  59. for(int k=0; k<100; ++k)
  60. {
  61. T x = internal::random<T>();
  62. Real y = numext::arg(x);
  63. VERIFY_IS_APPROX( y, std::arg(x) );
  64. }
  65. }
  66. template<typename T>
  67. struct check_sqrt_impl {
  68. static void run() {
  69. for (int i=0; i<1000; ++i) {
  70. const T x = numext::abs(internal::random<T>());
  71. const T sqrtx = numext::sqrt(x);
  72. VERIFY_IS_APPROX(sqrtx*sqrtx, x);
  73. }
  74. // Corner cases.
  75. const T zero = T(0);
  76. const T one = T(1);
  77. const T inf = std::numeric_limits<T>::infinity();
  78. const T nan = std::numeric_limits<T>::quiet_NaN();
  79. VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
  80. VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
  81. VERIFY((numext::isnan)(numext::sqrt(nan)));
  82. VERIFY((numext::isnan)(numext::sqrt(-one)));
  83. }
  84. };
  85. template<typename T>
  86. struct check_sqrt_impl<std::complex<T> > {
  87. static void run() {
  88. typedef typename std::complex<T> ComplexT;
  89. for (int i=0; i<1000; ++i) {
  90. const ComplexT x = internal::random<ComplexT>();
  91. const ComplexT sqrtx = numext::sqrt(x);
  92. VERIFY_IS_APPROX(sqrtx*sqrtx, x);
  93. }
  94. // Corner cases.
  95. const T zero = T(0);
  96. const T one = T(1);
  97. const T inf = std::numeric_limits<T>::infinity();
  98. const T nan = std::numeric_limits<T>::quiet_NaN();
  99. // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
  100. const int kNumCorners = 20;
  101. const ComplexT corners[kNumCorners][2] = {
  102. {ComplexT(zero, zero), ComplexT(zero, zero)},
  103. {ComplexT(-zero, zero), ComplexT(zero, zero)},
  104. {ComplexT(zero, -zero), ComplexT(zero, zero)},
  105. {ComplexT(-zero, -zero), ComplexT(zero, zero)},
  106. {ComplexT(one, inf), ComplexT(inf, inf)},
  107. {ComplexT(nan, inf), ComplexT(inf, inf)},
  108. {ComplexT(one, -inf), ComplexT(inf, -inf)},
  109. {ComplexT(nan, -inf), ComplexT(inf, -inf)},
  110. {ComplexT(-inf, one), ComplexT(zero, inf)},
  111. {ComplexT(inf, one), ComplexT(inf, zero)},
  112. {ComplexT(-inf, -one), ComplexT(zero, -inf)},
  113. {ComplexT(inf, -one), ComplexT(inf, -zero)},
  114. {ComplexT(-inf, nan), ComplexT(nan, inf)},
  115. {ComplexT(inf, nan), ComplexT(inf, nan)},
  116. {ComplexT(zero, nan), ComplexT(nan, nan)},
  117. {ComplexT(one, nan), ComplexT(nan, nan)},
  118. {ComplexT(nan, zero), ComplexT(nan, nan)},
  119. {ComplexT(nan, one), ComplexT(nan, nan)},
  120. {ComplexT(nan, -one), ComplexT(nan, nan)},
  121. {ComplexT(nan, nan), ComplexT(nan, nan)},
  122. };
  123. for (int i=0; i<kNumCorners; ++i) {
  124. const ComplexT& x = corners[i][0];
  125. const ComplexT sqrtx = corners[i][1];
  126. VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
  127. }
  128. }
  129. };
  130. template<typename T>
  131. void check_sqrt() {
  132. check_sqrt_impl<T>::run();
  133. }
  134. template<typename T>
  135. struct check_rsqrt_impl {
  136. static void run() {
  137. const T zero = T(0);
  138. const T one = T(1);
  139. const T inf = std::numeric_limits<T>::infinity();
  140. const T nan = std::numeric_limits<T>::quiet_NaN();
  141. for (int i=0; i<1000; ++i) {
  142. const T x = numext::abs(internal::random<T>());
  143. const T rsqrtx = numext::rsqrt(x);
  144. const T invx = one / x;
  145. VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
  146. }
  147. // Corner cases.
  148. VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
  149. VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
  150. VERIFY((numext::isnan)(numext::rsqrt(nan)));
  151. VERIFY((numext::isnan)(numext::rsqrt(-one)));
  152. }
  153. };
  154. template<typename T>
  155. struct check_rsqrt_impl<std::complex<T> > {
  156. static void run() {
  157. typedef typename std::complex<T> ComplexT;
  158. const T zero = T(0);
  159. const T one = T(1);
  160. const T inf = std::numeric_limits<T>::infinity();
  161. const T nan = std::numeric_limits<T>::quiet_NaN();
  162. for (int i=0; i<1000; ++i) {
  163. const ComplexT x = internal::random<ComplexT>();
  164. const ComplexT invx = ComplexT(one, zero) / x;
  165. const ComplexT rsqrtx = numext::rsqrt(x);
  166. VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
  167. }
  168. // GCC and MSVC differ in their treatment of 1/(0 + 0i)
  169. // GCC/clang = (inf, nan)
  170. // MSVC = (nan, nan)
  171. // and 1 / (x + inf i)
  172. // GCC/clang = (0, 0)
  173. // MSVC = (nan, nan)
  174. #if (EIGEN_COMP_GNUC)
  175. {
  176. const int kNumCorners = 20;
  177. const ComplexT corners[kNumCorners][2] = {
  178. // Only consistent across GCC, clang
  179. {ComplexT(zero, zero), ComplexT(zero, zero)},
  180. {ComplexT(-zero, zero), ComplexT(zero, zero)},
  181. {ComplexT(zero, -zero), ComplexT(zero, zero)},
  182. {ComplexT(-zero, -zero), ComplexT(zero, zero)},
  183. {ComplexT(one, inf), ComplexT(inf, inf)},
  184. {ComplexT(nan, inf), ComplexT(inf, inf)},
  185. {ComplexT(one, -inf), ComplexT(inf, -inf)},
  186. {ComplexT(nan, -inf), ComplexT(inf, -inf)},
  187. // Consistent across GCC, clang, MSVC
  188. {ComplexT(-inf, one), ComplexT(zero, inf)},
  189. {ComplexT(inf, one), ComplexT(inf, zero)},
  190. {ComplexT(-inf, -one), ComplexT(zero, -inf)},
  191. {ComplexT(inf, -one), ComplexT(inf, -zero)},
  192. {ComplexT(-inf, nan), ComplexT(nan, inf)},
  193. {ComplexT(inf, nan), ComplexT(inf, nan)},
  194. {ComplexT(zero, nan), ComplexT(nan, nan)},
  195. {ComplexT(one, nan), ComplexT(nan, nan)},
  196. {ComplexT(nan, zero), ComplexT(nan, nan)},
  197. {ComplexT(nan, one), ComplexT(nan, nan)},
  198. {ComplexT(nan, -one), ComplexT(nan, nan)},
  199. {ComplexT(nan, nan), ComplexT(nan, nan)},
  200. };
  201. for (int i=0; i<kNumCorners; ++i) {
  202. const ComplexT& x = corners[i][0];
  203. const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
  204. VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
  205. }
  206. }
  207. #endif
  208. }
  209. };
  210. template<typename T>
  211. void check_rsqrt() {
  212. check_rsqrt_impl<T>::run();
  213. }
  214. EIGEN_DECLARE_TEST(numext) {
  215. for(int k=0; k<g_repeat; ++k)
  216. {
  217. CALL_SUBTEST( check_abs<bool>() );
  218. CALL_SUBTEST( check_abs<signed char>() );
  219. CALL_SUBTEST( check_abs<unsigned char>() );
  220. CALL_SUBTEST( check_abs<short>() );
  221. CALL_SUBTEST( check_abs<unsigned short>() );
  222. CALL_SUBTEST( check_abs<int>() );
  223. CALL_SUBTEST( check_abs<unsigned int>() );
  224. CALL_SUBTEST( check_abs<long>() );
  225. CALL_SUBTEST( check_abs<unsigned long>() );
  226. CALL_SUBTEST( check_abs<half>() );
  227. CALL_SUBTEST( check_abs<bfloat16>() );
  228. CALL_SUBTEST( check_abs<float>() );
  229. CALL_SUBTEST( check_abs<double>() );
  230. CALL_SUBTEST( check_abs<long double>() );
  231. CALL_SUBTEST( check_abs<std::complex<float> >() );
  232. CALL_SUBTEST( check_abs<std::complex<double> >() );
  233. CALL_SUBTEST( check_arg<std::complex<float> >() );
  234. CALL_SUBTEST( check_arg<std::complex<double> >() );
  235. CALL_SUBTEST( check_sqrt<float>() );
  236. CALL_SUBTEST( check_sqrt<double>() );
  237. CALL_SUBTEST( check_sqrt<std::complex<float> >() );
  238. CALL_SUBTEST( check_sqrt<std::complex<double> >() );
  239. CALL_SUBTEST( check_rsqrt<float>() );
  240. CALL_SUBTEST( check_rsqrt<double>() );
  241. CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
  242. CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  243. }
  244. }