123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620 |
- #pragma once
- #include <complex>
- #include <c10/macros/Macros.h>
- #if defined(__CUDACC__) || defined(__HIPCC__)
- #include <thrust/complex.h>
- #endif
- C10_CLANG_DIAGNOSTIC_PUSH()
- #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
- C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
- #endif
- #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
- C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
- #endif
- namespace c10 {
- // c10::complex is an implementation of complex numbers that aims
- // to work on all devices supported by PyTorch
- //
// Most of the APIs duplicate std::complex
- // Reference: https://en.cppreference.com/w/cpp/numeric/complex
- //
- // [NOTE: Complex Operator Unification]
- // Operators currently use a mix of std::complex, thrust::complex, and
- // c10::complex internally. The end state is that all operators will use
- // c10::complex internally. Until then, there may be some hacks to support all
- // variants.
- //
- //
- // [Note on Constructors]
- //
- // The APIs of constructors are mostly copied from C++ standard:
- // https://en.cppreference.com/w/cpp/numeric/complex/complex
- //
- // Since C++14, all constructors are constexpr in std::complex
- //
- // There are three types of constructors:
- // - initializing from real and imag:
- // `constexpr complex( const T& re = T(), const T& im = T() );`
- // - implicitly-declared copy constructor
- // - converting constructors
- //
- // Converting constructors:
// - std::complex defines converting constructors between float/double/long
//   double, while we define converting constructors between float/double.
- // - For these converting constructors, upcasting is implicit, downcasting is
- // explicit.
// - We also define explicit conversions from std::complex/thrust::complex
// - Note that the conversion from thrust is not constexpr, presumably because
//   thrust does not declare the members it relies on as constexpr (unverified).
- //
- //
- // [Operator =]
- //
- // The APIs of operator = are mostly copied from C++ standard:
- // https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
- //
- // Since C++20, all operator= are constexpr. Although we are not building with
- // C++20, we also obey this behavior.
- //
- // There are three types of assign operator:
// - Assign a real value from the same scalar type
//   - In std, this is templated as `complex& operator=(const T& x)`
//     with the specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, we use
//     `complex& operator=(T x)`.
// - Copy assignment operator and converting assignment operator
//   - There is no specialization of converting assignment operators; which
//     types are convertible depends solely on whether the scalar types are
//     convertible.
- //
// In addition to the standard assignment, we also provide assignment operators
// from std::complex and thrust::complex.
- //
- //
- // [Casting operators]
- //
// std::complex does not have casting operators. We define explicit casting
// operators to std::complex and thrust::complex.
- //
- //
- // [Operator ""]
- //
// std::complex has custom literals `i`, `if` and `il` defined in namespace
// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define `_if` and `_id` to
// construct float/double complex literals.
- //
- //
- // [real() and imag()]
- //
// In C++20, there are two overloads of each of these functions: one returns
// the real/imag part and the other sets it; both are constexpr. We follow
// this design.
- //
- //
- // [Operator +=,-=,*=,/=]
- //
- // Since C++20, these operators become constexpr. In our implementation, they
- // are also constexpr.
- //
// There are two types of such operators: operating with a real number, or
// operating with another complex number. For operating with a real number,
// the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We follow the same signature as the
// float/double/long double overloads in std.
- //
- // [Unary operator +-]
- //
// Since C++20, they are constexpr. We also make them constexpr.
- //
- // [Binary operators +-*/]
- //
- // Each operator has three versions (taking + as example):
- // - complex + complex
- // - complex + real
- // - real + complex
- //
- // [Operator ==, !=]
- //
- // Each operator has three versions (taking == as example):
- // - complex == complex
- // - complex == real
- // - real == complex
- //
// Some of them were removed in C++20, but we keep them.
- //
- // [Operator <<, >>]
- //
- // These are implemented by casting to std::complex
- //
- //
- //
- // TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
- // because:
- // - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only supports float and double
- template <typename T>
- struct alignas(sizeof(T) * 2) complex {
- using value_type = T;
- T real_ = T(0);
- T imag_ = T(0);
- constexpr complex() = default;
- C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
- : real_(re), imag_(im) {}
- template <typename U>
- explicit constexpr complex(const std::complex<U>& other)
- : complex(other.real(), other.imag()) {}
- #if defined(__CUDACC__) || defined(__HIPCC__)
- template <typename U>
- explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
- : real_(other.real()), imag_(other.imag()) {}
- // NOTE can not be implemented as follow due to ROCm bug:
- // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
- // complex(other.real(), other.imag()) {}
- #endif
- // Use SFINAE to specialize casting constructor for c10::complex<float> and
- // c10::complex<double>
- template <typename U = T>
- C10_HOST_DEVICE explicit constexpr complex(
- const std::enable_if_t<std::is_same<U, float>::value, complex<double>>&
- other)
- : real_(other.real_), imag_(other.imag_) {}
- template <typename U = T>
- C10_HOST_DEVICE constexpr complex(
- const std::enable_if_t<std::is_same<U, double>::value, complex<float>>&
- other)
- : real_(other.real_), imag_(other.imag_) {}
- constexpr complex<T>& operator=(T re) {
- real_ = re;
- imag_ = 0;
- return *this;
- }
- constexpr complex<T>& operator+=(T re) {
- real_ += re;
- return *this;
- }
- constexpr complex<T>& operator-=(T re) {
- real_ -= re;
- return *this;
- }
- constexpr complex<T>& operator*=(T re) {
- real_ *= re;
- imag_ *= re;
- return *this;
- }
- constexpr complex<T>& operator/=(T re) {
- real_ /= re;
- imag_ /= re;
- return *this;
- }
- template <typename U>
- constexpr complex<T>& operator=(const complex<U>& rhs) {
- real_ = rhs.real();
- imag_ = rhs.imag();
- return *this;
- }
- template <typename U>
- constexpr complex<T>& operator+=(const complex<U>& rhs) {
- real_ += rhs.real();
- imag_ += rhs.imag();
- return *this;
- }
- template <typename U>
- constexpr complex<T>& operator-=(const complex<U>& rhs) {
- real_ -= rhs.real();
- imag_ -= rhs.imag();
- return *this;
- }
- template <typename U>
- constexpr complex<T>& operator*=(const complex<U>& rhs) {
- // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
- T a = real_;
- T b = imag_;
- U c = rhs.real();
- U d = rhs.imag();
- real_ = a * c - b * d;
- imag_ = a * d + b * c;
- return *this;
- }
- #ifdef __APPLE__
- #define FORCE_INLINE_APPLE __attribute__((always_inline))
- #else
- #define FORCE_INLINE_APPLE
- #endif
- template <typename U>
- constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
- __ubsan_ignore_float_divide_by_zero__ {
- // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
- // the calculation below follows numpy's complex division
- T ar = real_;
- T ai = imag_;
- U br = rhs.real();
- U bi = rhs.imag();
- #if defined(__GNUC__) && !defined(__clang__)
- // std::abs is already constexpr by gcc
- auto abs_br = std::abs(br);
- auto abs_bi = std::abs(bi);
- #else
- auto abs_br = br < 0 ? -br : br;
- auto abs_bi = bi < 0 ? -bi : bi;
- #endif
- if (abs_br >= abs_bi) {
- if (abs_br == 0 && abs_bi == 0) {
- /* divide by zeros should yield a complex inf or nan */
- real_ = ar / abs_br;
- imag_ = ai / abs_bi;
- } else {
- auto rat = bi / br;
- auto scl = 1.0 / (br + bi * rat);
- real_ = (ar + ai * rat) * scl;
- imag_ = (ai - ar * rat) * scl;
- }
- } else {
- auto rat = br / bi;
- auto scl = 1.0 / (bi + br * rat);
- real_ = (ar * rat + ai) * scl;
- imag_ = (ai * rat - ar) * scl;
- }
- return *this;
- }
- #undef FORCE_INLINE_APPLE
- template <typename U>
- constexpr complex<T>& operator=(const std::complex<U>& rhs) {
- real_ = rhs.real();
- imag_ = rhs.imag();
- return *this;
- }
- #if defined(__CUDACC__) || defined(__HIPCC__)
- template <typename U>
- C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
- real_ = rhs.real();
- imag_ = rhs.imag();
- return *this;
- }
- #endif
- template <typename U>
- explicit constexpr operator std::complex<U>() const {
- return std::complex<U>(std::complex<T>(real(), imag()));
- }
- #if defined(__CUDACC__) || defined(__HIPCC__)
- template <typename U>
- C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
- return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
- }
- #endif
- // consistent with NumPy behavior
- explicit constexpr operator bool() const {
- return real() || imag();
- }
- C10_HOST_DEVICE constexpr T real() const {
- return real_;
- }
- constexpr void real(T value) {
- real_ = value;
- }
- constexpr T imag() const {
- return imag_;
- }
- constexpr void imag(T value) {
- imag_ = value;
- }
- };
- namespace complex_literals {
- constexpr complex<float> operator"" _if(long double imag) {
- return complex<float>(0.0f, static_cast<float>(imag));
- }
- constexpr complex<double> operator"" _id(long double imag) {
- return complex<double>(0.0, static_cast<double>(imag));
- }
- constexpr complex<float> operator"" _if(unsigned long long imag) {
- return complex<float>(0.0f, static_cast<float>(imag));
- }
- constexpr complex<double> operator"" _id(unsigned long long imag) {
- return complex<double>(0.0, static_cast<double>(imag));
- }
- } // namespace complex_literals
- template <typename T>
- constexpr complex<T> operator+(const complex<T>& val) {
- return val;
- }
- template <typename T>
- constexpr complex<T> operator-(const complex<T>& val) {
- return complex<T>(-val.real(), -val.imag());
- }
- template <typename T>
- constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
- complex<T> result = lhs;
- return result += rhs;
- }
- template <typename T>
- constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
- complex<T> result = lhs;
- return result += rhs;
- }
- template <typename T>
- constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
- return complex<T>(lhs + rhs.real(), rhs.imag());
- }
- template <typename T>
- constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
- complex<T> result = lhs;
- return result -= rhs;
- }
- template <typename T>
- constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
- complex<T> result = lhs;
- return result -= rhs;
- }
- template <typename T>
- constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
- complex<T> result = -rhs;
- return result += lhs;
- }
- template <typename T>
- constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
- complex<T> result = lhs;
- return result *= rhs;
- }
- template <typename T>
- constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
- complex<T> result = lhs;
- return result *= rhs;
- }
- template <typename T>
- constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
- complex<T> result = rhs;
- return result *= lhs;
- }
- template <typename T>
- constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
- complex<T> result = lhs;
- return result /= rhs;
- }
- template <typename T>
- constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
- complex<T> result = lhs;
- return result /= rhs;
- }
- template <typename T>
- constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
- complex<T> result(lhs, T());
- return result /= rhs;
- }
- // Define operators between integral scalars and c10::complex. std::complex does
- // not support this when T is a floating-point number. This is useful because it
- // saves a lot of "static_cast" when operate a complex and an integer. This
- // makes the code both less verbose and potentially more efficient.
- #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
- typename std::enable_if_t< \
- std::is_floating_point<fT>::value && std::is_integral<iT>::value, \
- int> = 0
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
- return a + static_cast<fT>(b);
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
- return static_cast<fT>(a) + b;
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
- return a - static_cast<fT>(b);
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
- return static_cast<fT>(a) - b;
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
- return a * static_cast<fT>(b);
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
- return static_cast<fT>(a) * b;
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
- return a / static_cast<fT>(b);
- }
- template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
- constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
- return static_cast<fT>(a) / b;
- }
- #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
- template <typename T>
- constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
- return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
- }
- template <typename T>
- constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
- return (lhs.real() == rhs) && (lhs.imag() == T());
- }
- template <typename T>
- constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
- return (lhs == rhs.real()) && (T() == rhs.imag());
- }
- template <typename T>
- constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
- return !(lhs == rhs);
- }
- template <typename T>
- constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
- return !(lhs == rhs);
- }
- template <typename T>
- constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
- return !(lhs == rhs);
- }
- template <typename T, typename CharT, typename Traits>
- std::basic_ostream<CharT, Traits>& operator<<(
- std::basic_ostream<CharT, Traits>& os,
- const complex<T>& x) {
- return (os << static_cast<std::complex<T>>(x));
- }
- template <typename T, typename CharT, typename Traits>
- std::basic_istream<CharT, Traits>& operator>>(
- std::basic_istream<CharT, Traits>& is,
- complex<T>& x) {
- std::complex<T> tmp;
- is >> tmp;
- x = tmp;
- return is;
- }
- } // namespace c10
- // std functions
- //
// The implementation of these functions also follows the design of C++20
- namespace std {
- template <typename T>
- constexpr T real(const c10::complex<T>& z) {
- return z.real();
- }
- template <typename T>
- constexpr T imag(const c10::complex<T>& z) {
- return z.imag();
- }
- template <typename T>
- C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
- #if defined(__CUDACC__) || defined(__HIPCC__)
- return thrust::abs(static_cast<thrust::complex<T>>(z));
- #else
- return std::abs(static_cast<std::complex<T>>(z));
- #endif
- }
- #if defined(USE_ROCM)
- #define ROCm_Bug(x)
- #else
- #define ROCm_Bug(x) x
- #endif
- template <typename T>
- C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
- return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
- }
- #undef ROCm_Bug
- template <typename T>
- constexpr T norm(const c10::complex<T>& z) {
- return z.real() * z.real() + z.imag() * z.imag();
- }
- // For std::conj, there are other versions of it:
- // constexpr std::complex<float> conj( float z );
- // template< class DoubleOrInteger >
- // constexpr std::complex<double> conj( DoubleOrInteger z );
- // constexpr std::complex<long double> conj( long double z );
- // These are not implemented
- // TODO(@zasdfgbnm): implement them as c10::conj
- template <typename T>
- constexpr c10::complex<T> conj(const c10::complex<T>& z) {
- return c10::complex<T>(z.real(), -z.imag());
- }
- // Thrust does not have complex --> complex version of thrust::proj,
- // so this function is not implemented at c10 right now.
- // TODO(@zasdfgbnm): implement it by ourselves
// There is no std::polar overload for c10::complex, because std::polar always
// returns std::complex. Use c10::polar instead.
- } // namespace std
- namespace c10 {
- template <typename T>
- C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
- #if defined(__CUDACC__) || defined(__HIPCC__)
- return static_cast<complex<T>>(thrust::polar(r, theta));
- #else
- // std::polar() requires r >= 0, so spell out the explicit implementation to
- // avoid a branch.
- return complex<T>(r * std::cos(theta), r * std::sin(theta));
- #endif
- }
- } // namespace c10
- C10_CLANG_DIAGNOSTIC_POP()
- #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
- // math functions are included in a separate file
- #include <c10/util/complex_math.h> // IWYU pragma: keep
- // utilities for complex types
- #include <c10/util/complex_utils.h> // IWYU pragma: keep
- #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
|