complex_math.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
  2. #error \
  3. "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
  4. #endif
  5. namespace c10_complex_math {
  6. // Exponential functions
  7. template <typename T>
  8. C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
  9. #if defined(__CUDACC__) || defined(__HIPCC__)
  10. return static_cast<c10::complex<T>>(
  11. thrust::exp(static_cast<thrust::complex<T>>(x)));
  12. #else
  13. return static_cast<c10::complex<T>>(
  14. std::exp(static_cast<std::complex<T>>(x)));
  15. #endif
  16. }
  17. template <typename T>
  18. C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
  19. #if defined(__CUDACC__) || defined(__HIPCC__)
  20. return static_cast<c10::complex<T>>(
  21. thrust::log(static_cast<thrust::complex<T>>(x)));
  22. #else
  23. return static_cast<c10::complex<T>>(
  24. std::log(static_cast<std::complex<T>>(x)));
  25. #endif
  26. }
  27. template <typename T>
  28. C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
  29. #if defined(__CUDACC__) || defined(__HIPCC__)
  30. return static_cast<c10::complex<T>>(
  31. thrust::log10(static_cast<thrust::complex<T>>(x)));
  32. #else
  33. return static_cast<c10::complex<T>>(
  34. std::log10(static_cast<std::complex<T>>(x)));
  35. #endif
  36. }
  37. template <typename T>
  38. C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
  39. const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
  40. return c10_complex_math::log(x) / log2;
  41. }
  42. // Power functions
  43. //
  44. #if defined(_LIBCPP_VERSION) || \
  45. (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
  46. namespace _detail {
  47. C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
  48. C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
  49. C10_API c10::complex<float> acos(const c10::complex<float>& in);
  50. C10_API c10::complex<double> acos(const c10::complex<double>& in);
  51. }; // namespace _detail
  52. #endif
  53. template <typename T>
  54. C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
  55. #if defined(__CUDACC__) || defined(__HIPCC__)
  56. return static_cast<c10::complex<T>>(
  57. thrust::sqrt(static_cast<thrust::complex<T>>(x)));
  58. #elif !( \
  59. defined(_LIBCPP_VERSION) || \
  60. (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
  61. return static_cast<c10::complex<T>>(
  62. std::sqrt(static_cast<std::complex<T>>(x)));
  63. #else
  64. return _detail::sqrt(x);
  65. #endif
  66. }
  67. template <typename T>
  68. C10_HOST_DEVICE inline c10::complex<T> pow(
  69. const c10::complex<T>& x,
  70. const c10::complex<T>& y) {
  71. #if defined(__CUDACC__) || defined(__HIPCC__)
  72. return static_cast<c10::complex<T>>(thrust::pow(
  73. static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
  74. #else
  75. return static_cast<c10::complex<T>>(std::pow(
  76. static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
  77. #endif
  78. }
  79. template <typename T>
  80. C10_HOST_DEVICE inline c10::complex<T> pow(
  81. const c10::complex<T>& x,
  82. const T& y) {
  83. #if defined(__CUDACC__) || defined(__HIPCC__)
  84. return static_cast<c10::complex<T>>(
  85. thrust::pow(static_cast<thrust::complex<T>>(x), y));
  86. #else
  87. return static_cast<c10::complex<T>>(
  88. std::pow(static_cast<std::complex<T>>(x), y));
  89. #endif
  90. }
  91. template <typename T>
  92. C10_HOST_DEVICE inline c10::complex<T> pow(
  93. const T& x,
  94. const c10::complex<T>& y) {
  95. #if defined(__CUDACC__) || defined(__HIPCC__)
  96. return static_cast<c10::complex<T>>(
  97. thrust::pow(x, static_cast<thrust::complex<T>>(y)));
  98. #else
  99. return static_cast<c10::complex<T>>(
  100. std::pow(x, static_cast<std::complex<T>>(y)));
  101. #endif
  102. }
  103. template <typename T, typename U>
  104. C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
  105. const c10::complex<T>& x,
  106. const c10::complex<U>& y) {
  107. #if defined(__CUDACC__) || defined(__HIPCC__)
  108. return static_cast<c10::complex<T>>(thrust::pow(
  109. static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
  110. #else
  111. return static_cast<c10::complex<T>>(std::pow(
  112. static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
  113. #endif
  114. }
  115. template <typename T, typename U>
  116. C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
  117. const c10::complex<T>& x,
  118. const U& y) {
  119. #if defined(__CUDACC__) || defined(__HIPCC__)
  120. return static_cast<c10::complex<T>>(
  121. thrust::pow(static_cast<thrust::complex<T>>(x), y));
  122. #else
  123. return static_cast<c10::complex<T>>(
  124. std::pow(static_cast<std::complex<T>>(x), y));
  125. #endif
  126. }
  127. template <typename T, typename U>
  128. C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
  129. const T& x,
  130. const c10::complex<U>& y) {
  131. #if defined(__CUDACC__) || defined(__HIPCC__)
  132. return static_cast<c10::complex<T>>(
  133. thrust::pow(x, static_cast<thrust::complex<T>>(y)));
  134. #else
  135. return static_cast<c10::complex<T>>(
  136. std::pow(x, static_cast<std::complex<T>>(y)));
  137. #endif
  138. }
  139. // Trigonometric functions
  140. template <typename T>
  141. C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
  142. #if defined(__CUDACC__) || defined(__HIPCC__)
  143. return static_cast<c10::complex<T>>(
  144. thrust::sin(static_cast<thrust::complex<T>>(x)));
  145. #else
  146. return static_cast<c10::complex<T>>(
  147. std::sin(static_cast<std::complex<T>>(x)));
  148. #endif
  149. }
  150. template <typename T>
  151. C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
  152. #if defined(__CUDACC__) || defined(__HIPCC__)
  153. return static_cast<c10::complex<T>>(
  154. thrust::cos(static_cast<thrust::complex<T>>(x)));
  155. #else
  156. return static_cast<c10::complex<T>>(
  157. std::cos(static_cast<std::complex<T>>(x)));
  158. #endif
  159. }
  160. template <typename T>
  161. C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
  162. #if defined(__CUDACC__) || defined(__HIPCC__)
  163. return static_cast<c10::complex<T>>(
  164. thrust::tan(static_cast<thrust::complex<T>>(x)));
  165. #else
  166. return static_cast<c10::complex<T>>(
  167. std::tan(static_cast<std::complex<T>>(x)));
  168. #endif
  169. }
  170. template <typename T>
  171. C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
  172. #if defined(__CUDACC__) || defined(__HIPCC__)
  173. return static_cast<c10::complex<T>>(
  174. thrust::asin(static_cast<thrust::complex<T>>(x)));
  175. #else
  176. return static_cast<c10::complex<T>>(
  177. std::asin(static_cast<std::complex<T>>(x)));
  178. #endif
  179. }
  180. template <typename T>
  181. C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
  182. #if defined(__CUDACC__) || defined(__HIPCC__)
  183. return static_cast<c10::complex<T>>(
  184. thrust::acos(static_cast<thrust::complex<T>>(x)));
  185. #elif !defined(_LIBCPP_VERSION)
  186. return static_cast<c10::complex<T>>(
  187. std::acos(static_cast<std::complex<T>>(x)));
  188. #else
  189. return _detail::acos(x);
  190. #endif
  191. }
  192. template <typename T>
  193. C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
  194. #if defined(__CUDACC__) || defined(__HIPCC__)
  195. return static_cast<c10::complex<T>>(
  196. thrust::atan(static_cast<thrust::complex<T>>(x)));
  197. #else
  198. return static_cast<c10::complex<T>>(
  199. std::atan(static_cast<std::complex<T>>(x)));
  200. #endif
  201. }
  202. // Hyperbolic functions
  203. template <typename T>
  204. C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
  205. #if defined(__CUDACC__) || defined(__HIPCC__)
  206. return static_cast<c10::complex<T>>(
  207. thrust::sinh(static_cast<thrust::complex<T>>(x)));
  208. #else
  209. return static_cast<c10::complex<T>>(
  210. std::sinh(static_cast<std::complex<T>>(x)));
  211. #endif
  212. }
  213. template <typename T>
  214. C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
  215. #if defined(__CUDACC__) || defined(__HIPCC__)
  216. return static_cast<c10::complex<T>>(
  217. thrust::cosh(static_cast<thrust::complex<T>>(x)));
  218. #else
  219. return static_cast<c10::complex<T>>(
  220. std::cosh(static_cast<std::complex<T>>(x)));
  221. #endif
  222. }
  223. template <typename T>
  224. C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
  225. #if defined(__CUDACC__) || defined(__HIPCC__)
  226. return static_cast<c10::complex<T>>(
  227. thrust::tanh(static_cast<thrust::complex<T>>(x)));
  228. #else
  229. return static_cast<c10::complex<T>>(
  230. std::tanh(static_cast<std::complex<T>>(x)));
  231. #endif
  232. }
  233. template <typename T>
  234. C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
  235. #if defined(__CUDACC__) || defined(__HIPCC__)
  236. return static_cast<c10::complex<T>>(
  237. thrust::asinh(static_cast<thrust::complex<T>>(x)));
  238. #else
  239. return static_cast<c10::complex<T>>(
  240. std::asinh(static_cast<std::complex<T>>(x)));
  241. #endif
  242. }
  243. template <typename T>
  244. C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
  245. #if defined(__CUDACC__) || defined(__HIPCC__)
  246. return static_cast<c10::complex<T>>(
  247. thrust::acosh(static_cast<thrust::complex<T>>(x)));
  248. #else
  249. return static_cast<c10::complex<T>>(
  250. std::acosh(static_cast<std::complex<T>>(x)));
  251. #endif
  252. }
  253. template <typename T>
  254. C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
  255. #if defined(__CUDACC__) || defined(__HIPCC__)
  256. return static_cast<c10::complex<T>>(
  257. thrust::atanh(static_cast<thrust::complex<T>>(x)));
  258. #else
  259. return static_cast<c10::complex<T>>(
  260. std::atanh(static_cast<std::complex<T>>(x)));
  261. #endif
  262. }
  263. template <typename T>
  264. C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
  265. // log1p(z) = log(1 + z)
  266. // Let's define 1 + z = r * e ^ (i * a), then we have
  267. // log(r * e ^ (i * a)) = log(r) + i * a
  268. // With z = x + iy, the term r can be written as
  269. // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
  270. // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
  271. // So, log(r) is
  272. // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
  273. // = 0.5 * log1p(x * (x + 2) + y ^ 2)
  274. // we need to use the expression only on certain condition to avoid overflow
  275. // and underflow from `(x * (x + 2) + y ^ 2)`
  276. T x = z.real();
  277. T y = z.imag();
  278. T zabs = std::abs(z);
  279. T theta = std::atan2(y, x + T(1));
  280. if (zabs < 0.5) {
  281. T r = x * (T(2) + x) + y * y;
  282. if (r == 0) { // handle underflow
  283. return {x, theta};
  284. }
  285. return {T(0.5) * std::log1p(r), theta};
  286. } else {
  287. T z0 = std::hypot(x + 1, y);
  288. return {std::log(z0), theta};
  289. }
  290. }
  291. } // namespace c10_complex_math
  292. using c10_complex_math::acos;
  293. using c10_complex_math::acosh;
  294. using c10_complex_math::asin;
  295. using c10_complex_math::asinh;
  296. using c10_complex_math::atan;
  297. using c10_complex_math::atanh;
  298. using c10_complex_math::cos;
  299. using c10_complex_math::cosh;
  300. using c10_complex_math::exp;
  301. using c10_complex_math::log;
  302. using c10_complex_math::log10;
  303. using c10_complex_math::log1p;
  304. using c10_complex_math::log2;
  305. using c10_complex_math::pow;
  306. using c10_complex_math::sin;
  307. using c10_complex_math::sinh;
  308. using c10_complex_math::sqrt;
  309. using c10_complex_math::tan;
  310. using c10_complex_math::tanh;
  311. namespace std {
  312. using c10_complex_math::acos;
  313. using c10_complex_math::acosh;
  314. using c10_complex_math::asin;
  315. using c10_complex_math::asinh;
  316. using c10_complex_math::atan;
  317. using c10_complex_math::atanh;
  318. using c10_complex_math::cos;
  319. using c10_complex_math::cosh;
  320. using c10_complex_math::exp;
  321. using c10_complex_math::log;
  322. using c10_complex_math::log10;
  323. using c10_complex_math::log1p;
  324. using c10_complex_math::log2;
  325. using c10_complex_math::pow;
  326. using c10_complex_math::sin;
  327. using c10_complex_math::sinh;
  328. using c10_complex_math::sqrt;
  329. using c10_complex_math::tan;
  330. using c10_complex_math::tanh;
  331. } // namespace std