complex.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. #pragma once
  2. #include <complex>
  3. #include <c10/macros/Macros.h>
  4. #if defined(__CUDACC__) || defined(__HIPCC__)
  5. #include <thrust/complex.h>
  6. #endif
  7. C10_CLANG_DIAGNOSTIC_PUSH()
  8. #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
  9. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
  10. #endif
  11. #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
  12. C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
  13. #endif
  14. namespace c10 {
  15. // c10::complex is an implementation of complex numbers that aims
  16. // to work on all devices supported by PyTorch
  17. //
  18. // Most of the APIs duplicate std::complex
  19. // Reference: https://en.cppreference.com/w/cpp/numeric/complex
  20. //
  21. // [NOTE: Complex Operator Unification]
  22. // Operators currently use a mix of std::complex, thrust::complex, and
  23. // c10::complex internally. The end state is that all operators will use
  24. // c10::complex internally. Until then, there may be some hacks to support all
  25. // variants.
  26. //
  27. //
  28. // [Note on Constructors]
  29. //
  30. // The APIs of constructors are mostly copied from C++ standard:
  31. // https://en.cppreference.com/w/cpp/numeric/complex/complex
  32. //
  33. // Since C++14, all constructors are constexpr in std::complex
  34. //
  35. // There are three types of constructors:
  36. // - initializing from real and imag:
  37. // `constexpr complex( const T& re = T(), const T& im = T() );`
  38. // - implicitly-declared copy constructor
  39. // - converting constructors
  40. //
  41. // Converting constructors:
  42. // - std::complex defines converting constructor between float/double/long
  43. // double,
  44. // while we define converting constructor between float/double.
  45. // - For these converting constructors, upcasting is implicit, downcasting is
  46. // explicit.
  47. // - We also define explicit casting from std::complex/thrust::complex
  48. // - Note that the conversion from thrust is not constexpr, because
  49. // thrust does not define them as constexpr ????
  50. //
  51. //
  52. // [Operator =]
  53. //
  54. // The APIs of operator = are mostly copied from C++ standard:
  55. // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
  56. //
  57. // Since C++20, all operator= are constexpr. Although we are not building with
  58. // C++20, we also obey this behavior.
  59. //
  60. // There are three types of assign operator:
  61. // - Assign a real value from the same scalar type
  62. // - In std, this is templated as complex& operator=(const T& x)
  63. // with specialization `complex& operator=(T x)` for float/double/long
  64. // double. Since we only support float and double, we will use `complex&
  65. // operator=(T x)`
  66. // - Copy assignment operator and converting assignment operator
  67. // - There is no specialization of converting assignment operators; which type
  68. // is
  69. // convertible depends solely on whether the scalar type is convertible
  70. //
  71. // In addition to the standard assignment, we also provide assignment operators
  72. // with std and thrust
  73. //
  74. //
  75. // [Casting operators]
  76. //
  77. // std::complex does not have casting operators. We define casting operators
  78. // casting to std::complex and thrust::complex
  79. //
  80. //
  81. // [Operator ""]
  82. //
  83. // std::complex has custom literals `i`, `if` and `il` defined in namespace
  84. // `std::literals::complex_literals`. We define our own custom literals in the
  85. // namespace `c10::complex_literals`. Our custom literals do not follow the
  86. // same behavior as in std::complex, instead, we define _if, _id to construct
  87. // float/double complex literals.
  88. //
  89. //
  90. // [real() and imag()]
  91. //
  92. // In C++20, there are two overloads of these functions, one is to return the
  93. // real/imag, another is to set real/imag, they are both constexpr. We follow
  94. // this design.
  95. //
  96. //
  97. // [Operator +=,-=,*=,/=]
  98. //
  99. // Since C++20, these operators become constexpr. In our implementation, they
  100. // are also constexpr.
  101. //
  102. // There are two types of such operators: operating with a real number, or
  103. // operating with another complex number. For the operating with a real number,
  104. // the generic template form has argument type `const T &`, while the overload
  105. // for float/double/long double has `T`. We will follow the same type as
  106. // float/double/long double in std.
  107. //
  108. // [Unary operator +-]
  109. //
  110. // Since C++20, they are constexpr. We also make them constexpr
  111. //
  112. // [Binary operators +-*/]
  113. //
  114. // Each operator has three versions (taking + as example):
  115. // - complex + complex
  116. // - complex + real
  117. // - real + complex
  118. //
  119. // [Operator ==, !=]
  120. //
  121. // Each operator has three versions (taking == as example):
  122. // - complex == complex
  123. // - complex == real
  124. // - real == complex
  125. //
  126. // Some of them are removed on C++20, but we decide to keep them
  127. //
  128. // [Operator <<, >>]
  129. //
  130. // These are implemented by casting to std::complex
  131. //
  132. //
  133. //
  134. // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
  135. // because:
  136. // - lots of members and functions of c10::Half are not constexpr
  137. // - thrust::complex only support float and double
  138. template <typename T>
  139. struct alignas(sizeof(T) * 2) complex {
  140. using value_type = T;
  141. T real_ = T(0);
  142. T imag_ = T(0);
  143. constexpr complex() = default;
  144. C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
  145. : real_(re), imag_(im) {}
  146. template <typename U>
  147. explicit constexpr complex(const std::complex<U>& other)
  148. : complex(other.real(), other.imag()) {}
  149. #if defined(__CUDACC__) || defined(__HIPCC__)
  150. template <typename U>
  151. explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
  152. : real_(other.real()), imag_(other.imag()) {}
  153. // NOTE can not be implemented as follow due to ROCm bug:
  154. // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
  155. // complex(other.real(), other.imag()) {}
  156. #endif
  157. // Use SFINAE to specialize casting constructor for c10::complex<float> and
  158. // c10::complex<double>
  159. template <typename U = T>
  160. C10_HOST_DEVICE explicit constexpr complex(
  161. const std::enable_if_t<std::is_same<U, float>::value, complex<double>>&
  162. other)
  163. : real_(other.real_), imag_(other.imag_) {}
  164. template <typename U = T>
  165. C10_HOST_DEVICE constexpr complex(
  166. const std::enable_if_t<std::is_same<U, double>::value, complex<float>>&
  167. other)
  168. : real_(other.real_), imag_(other.imag_) {}
  169. constexpr complex<T>& operator=(T re) {
  170. real_ = re;
  171. imag_ = 0;
  172. return *this;
  173. }
  174. constexpr complex<T>& operator+=(T re) {
  175. real_ += re;
  176. return *this;
  177. }
  178. constexpr complex<T>& operator-=(T re) {
  179. real_ -= re;
  180. return *this;
  181. }
  182. constexpr complex<T>& operator*=(T re) {
  183. real_ *= re;
  184. imag_ *= re;
  185. return *this;
  186. }
  187. constexpr complex<T>& operator/=(T re) {
  188. real_ /= re;
  189. imag_ /= re;
  190. return *this;
  191. }
  192. template <typename U>
  193. constexpr complex<T>& operator=(const complex<U>& rhs) {
  194. real_ = rhs.real();
  195. imag_ = rhs.imag();
  196. return *this;
  197. }
  198. template <typename U>
  199. constexpr complex<T>& operator+=(const complex<U>& rhs) {
  200. real_ += rhs.real();
  201. imag_ += rhs.imag();
  202. return *this;
  203. }
  204. template <typename U>
  205. constexpr complex<T>& operator-=(const complex<U>& rhs) {
  206. real_ -= rhs.real();
  207. imag_ -= rhs.imag();
  208. return *this;
  209. }
  210. template <typename U>
  211. constexpr complex<T>& operator*=(const complex<U>& rhs) {
  212. // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
  213. T a = real_;
  214. T b = imag_;
  215. U c = rhs.real();
  216. U d = rhs.imag();
  217. real_ = a * c - b * d;
  218. imag_ = a * d + b * c;
  219. return *this;
  220. }
  221. #ifdef __APPLE__
  222. #define FORCE_INLINE_APPLE __attribute__((always_inline))
  223. #else
  224. #define FORCE_INLINE_APPLE
  225. #endif
  226. template <typename U>
  227. constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
  228. __ubsan_ignore_float_divide_by_zero__ {
  229. // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
  230. // the calculation below follows numpy's complex division
  231. T ar = real_;
  232. T ai = imag_;
  233. U br = rhs.real();
  234. U bi = rhs.imag();
  235. #if defined(__GNUC__) && !defined(__clang__)
  236. // std::abs is already constexpr by gcc
  237. auto abs_br = std::abs(br);
  238. auto abs_bi = std::abs(bi);
  239. #else
  240. auto abs_br = br < 0 ? -br : br;
  241. auto abs_bi = bi < 0 ? -bi : bi;
  242. #endif
  243. if (abs_br >= abs_bi) {
  244. if (abs_br == 0 && abs_bi == 0) {
  245. /* divide by zeros should yield a complex inf or nan */
  246. real_ = ar / abs_br;
  247. imag_ = ai / abs_bi;
  248. } else {
  249. auto rat = bi / br;
  250. auto scl = 1.0 / (br + bi * rat);
  251. real_ = (ar + ai * rat) * scl;
  252. imag_ = (ai - ar * rat) * scl;
  253. }
  254. } else {
  255. auto rat = br / bi;
  256. auto scl = 1.0 / (bi + br * rat);
  257. real_ = (ar * rat + ai) * scl;
  258. imag_ = (ai * rat - ar) * scl;
  259. }
  260. return *this;
  261. }
  262. #undef FORCE_INLINE_APPLE
  263. template <typename U>
  264. constexpr complex<T>& operator=(const std::complex<U>& rhs) {
  265. real_ = rhs.real();
  266. imag_ = rhs.imag();
  267. return *this;
  268. }
  269. #if defined(__CUDACC__) || defined(__HIPCC__)
  270. template <typename U>
  271. C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
  272. real_ = rhs.real();
  273. imag_ = rhs.imag();
  274. return *this;
  275. }
  276. #endif
  277. template <typename U>
  278. explicit constexpr operator std::complex<U>() const {
  279. return std::complex<U>(std::complex<T>(real(), imag()));
  280. }
  281. #if defined(__CUDACC__) || defined(__HIPCC__)
  282. template <typename U>
  283. C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
  284. return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
  285. }
  286. #endif
  287. // consistent with NumPy behavior
  288. explicit constexpr operator bool() const {
  289. return real() || imag();
  290. }
  291. C10_HOST_DEVICE constexpr T real() const {
  292. return real_;
  293. }
  294. constexpr void real(T value) {
  295. real_ = value;
  296. }
  297. constexpr T imag() const {
  298. return imag_;
  299. }
  300. constexpr void imag(T value) {
  301. imag_ = value;
  302. }
  303. };
  304. namespace complex_literals {
  305. constexpr complex<float> operator"" _if(long double imag) {
  306. return complex<float>(0.0f, static_cast<float>(imag));
  307. }
  308. constexpr complex<double> operator"" _id(long double imag) {
  309. return complex<double>(0.0, static_cast<double>(imag));
  310. }
  311. constexpr complex<float> operator"" _if(unsigned long long imag) {
  312. return complex<float>(0.0f, static_cast<float>(imag));
  313. }
  314. constexpr complex<double> operator"" _id(unsigned long long imag) {
  315. return complex<double>(0.0, static_cast<double>(imag));
  316. }
  317. } // namespace complex_literals
  318. template <typename T>
  319. constexpr complex<T> operator+(const complex<T>& val) {
  320. return val;
  321. }
  322. template <typename T>
  323. constexpr complex<T> operator-(const complex<T>& val) {
  324. return complex<T>(-val.real(), -val.imag());
  325. }
  326. template <typename T>
  327. constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
  328. complex<T> result = lhs;
  329. return result += rhs;
  330. }
  331. template <typename T>
  332. constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
  333. complex<T> result = lhs;
  334. return result += rhs;
  335. }
  336. template <typename T>
  337. constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
  338. return complex<T>(lhs + rhs.real(), rhs.imag());
  339. }
  340. template <typename T>
  341. constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
  342. complex<T> result = lhs;
  343. return result -= rhs;
  344. }
  345. template <typename T>
  346. constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
  347. complex<T> result = lhs;
  348. return result -= rhs;
  349. }
  350. template <typename T>
  351. constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
  352. complex<T> result = -rhs;
  353. return result += lhs;
  354. }
  355. template <typename T>
  356. constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
  357. complex<T> result = lhs;
  358. return result *= rhs;
  359. }
  360. template <typename T>
  361. constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
  362. complex<T> result = lhs;
  363. return result *= rhs;
  364. }
  365. template <typename T>
  366. constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
  367. complex<T> result = rhs;
  368. return result *= lhs;
  369. }
  370. template <typename T>
  371. constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
  372. complex<T> result = lhs;
  373. return result /= rhs;
  374. }
  375. template <typename T>
  376. constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
  377. complex<T> result = lhs;
  378. return result /= rhs;
  379. }
  380. template <typename T>
  381. constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
  382. complex<T> result(lhs, T());
  383. return result /= rhs;
  384. }
  385. // Define operators between integral scalars and c10::complex. std::complex does
  386. // not support this when T is a floating-point number. This is useful because it
  387. // saves a lot of "static_cast" when operate a complex and an integer. This
  388. // makes the code both less verbose and potentially more efficient.
  389. #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
  390. typename std::enable_if_t< \
  391. std::is_floating_point<fT>::value && std::is_integral<iT>::value, \
  392. int> = 0
  393. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  394. constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
  395. return a + static_cast<fT>(b);
  396. }
  397. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  398. constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
  399. return static_cast<fT>(a) + b;
  400. }
  401. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  402. constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
  403. return a - static_cast<fT>(b);
  404. }
  405. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  406. constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
  407. return static_cast<fT>(a) - b;
  408. }
  409. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  410. constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
  411. return a * static_cast<fT>(b);
  412. }
  413. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  414. constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
  415. return static_cast<fT>(a) * b;
  416. }
  417. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  418. constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
  419. return a / static_cast<fT>(b);
  420. }
  421. template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
  422. constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
  423. return static_cast<fT>(a) / b;
  424. }
  425. #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
  426. template <typename T>
  427. constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
  428. return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
  429. }
  430. template <typename T>
  431. constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
  432. return (lhs.real() == rhs) && (lhs.imag() == T());
  433. }
  434. template <typename T>
  435. constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
  436. return (lhs == rhs.real()) && (T() == rhs.imag());
  437. }
  438. template <typename T>
  439. constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
  440. return !(lhs == rhs);
  441. }
  442. template <typename T>
  443. constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
  444. return !(lhs == rhs);
  445. }
  446. template <typename T>
  447. constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
  448. return !(lhs == rhs);
  449. }
  450. template <typename T, typename CharT, typename Traits>
  451. std::basic_ostream<CharT, Traits>& operator<<(
  452. std::basic_ostream<CharT, Traits>& os,
  453. const complex<T>& x) {
  454. return (os << static_cast<std::complex<T>>(x));
  455. }
  456. template <typename T, typename CharT, typename Traits>
  457. std::basic_istream<CharT, Traits>& operator>>(
  458. std::basic_istream<CharT, Traits>& is,
  459. complex<T>& x) {
  460. std::complex<T> tmp;
  461. is >> tmp;
  462. x = tmp;
  463. return is;
  464. }
  465. } // namespace c10
  466. // std functions
  467. //
  468. // The implementation of these functions also follow the design of C++20
  469. namespace std {
  470. template <typename T>
  471. constexpr T real(const c10::complex<T>& z) {
  472. return z.real();
  473. }
  474. template <typename T>
  475. constexpr T imag(const c10::complex<T>& z) {
  476. return z.imag();
  477. }
  478. template <typename T>
  479. C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
  480. #if defined(__CUDACC__) || defined(__HIPCC__)
  481. return thrust::abs(static_cast<thrust::complex<T>>(z));
  482. #else
  483. return std::abs(static_cast<std::complex<T>>(z));
  484. #endif
  485. }
  486. #if defined(USE_ROCM)
  487. #define ROCm_Bug(x)
  488. #else
  489. #define ROCm_Bug(x) x
  490. #endif
  491. template <typename T>
  492. C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
  493. return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
  494. }
  495. #undef ROCm_Bug
  496. template <typename T>
  497. constexpr T norm(const c10::complex<T>& z) {
  498. return z.real() * z.real() + z.imag() * z.imag();
  499. }
  500. // For std::conj, there are other versions of it:
  501. // constexpr std::complex<float> conj( float z );
  502. // template< class DoubleOrInteger >
  503. // constexpr std::complex<double> conj( DoubleOrInteger z );
  504. // constexpr std::complex<long double> conj( long double z );
  505. // These are not implemented
  506. // TODO(@zasdfgbnm): implement them as c10::conj
  507. template <typename T>
  508. constexpr c10::complex<T> conj(const c10::complex<T>& z) {
  509. return c10::complex<T>(z.real(), -z.imag());
  510. }
  511. // Thrust does not have complex --> complex version of thrust::proj,
  512. // so this function is not implemented at c10 right now.
  513. // TODO(@zasdfgbnm): implement it by ourselves
  514. // There is no c10 version of std::polar, because std::polar always
  515. // returns std::complex. Use c10::polar instead;
  516. } // namespace std
  517. namespace c10 {
  518. template <typename T>
  519. C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
  520. #if defined(__CUDACC__) || defined(__HIPCC__)
  521. return static_cast<complex<T>>(thrust::polar(r, theta));
  522. #else
  523. // std::polar() requires r >= 0, so spell out the explicit implementation to
  524. // avoid a branch.
  525. return complex<T>(r * std::cos(theta), r * std::sin(theta));
  526. #endif
  527. }
  528. } // namespace c10
  529. C10_CLANG_DIAGNOSTIC_POP()
  530. #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
  531. // math functions are included in a separate file
  532. #include <c10/util/complex_math.h> // IWYU pragma: keep
  533. // utilities for complex types
  534. #include <c10/util/complex_utils.h> // IWYU pragma: keep
  535. #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H