#pragma once // Complex number math operations that act as no-ops for other dtypes. #include #include #include #include namespace at { namespace native { inline namespace CPU_CAPABILITY { template inline VALUE_TYPE zabs (SCALAR_TYPE z) { return z; } template<> inline c10::complex zabs > (c10::complex z) { return c10::complex(std::abs(z)); } template<> inline float zabs , float> (c10::complex z) { return std::abs(z); } template<> inline c10::complex zabs > (c10::complex z) { return c10::complex(std::abs(z)); } template<> inline double zabs , double> (c10::complex z) { return std::abs(z); } // This overload corresponds to non-complex dtypes. // The function is consistent with its NumPy equivalent // for non-complex dtypes where `pi` is returned for // negative real numbers and `0` is returned for 0 or positive // real numbers. // Note: `nan` is propagated. template inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { if (at::_isnan(z)) { return z; } return z < 0 ? c10::pi : 0; } template<> inline c10::complex angle_impl > (c10::complex z) { return c10::complex(std::arg(z), 0.0); } template<> inline float angle_impl , float> (c10::complex z) { return std::arg(z); } template<> inline c10::complex angle_impl > (c10::complex z) { return c10::complex(std::arg(z), 0.0); } template<> inline double angle_impl , double> (c10::complex z) { return std::arg(z); } template constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) { return z; //No-Op } template<> constexpr c10::complex real_impl > (c10::complex z) { return c10::complex(z.real(), 0.0); } template<> constexpr float real_impl , float> (c10::complex z) { return z.real(); } template<> constexpr c10::complex real_impl > (c10::complex z) { return c10::complex(z.real(), 0.0); } template<> constexpr double real_impl , double> (c10::complex z) { return z.real(); } template constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) { return 0; } template<> constexpr c10::complex imag_impl > (c10::complex z) { return c10::complex(z.imag(), 0.0); } template<> constexpr float imag_impl , float> (c10::complex z) { return z.imag(); } template<> constexpr c10::complex imag_impl > (c10::complex z) { return c10::complex(z.imag(), 0.0); } template<> constexpr double imag_impl , double> (c10::complex z) { return z.imag(); } template inline TYPE conj_impl (TYPE z) { return z; //No-Op } template<> inline c10::complex conj_impl > (c10::complex z) { return c10::complex{z.real(), -z.imag()}; } template<> inline c10::complex conj_impl > (c10::complex z) { return c10::complex(z.real(), -z.imag()); } template<> inline c10::complex conj_impl > (c10::complex z) { return c10::complex(z.real(), -z.imag()); } template inline TYPE ceil_impl (TYPE z) { return std::ceil(z); } template <> inline c10::complex ceil_impl (c10::complex z) { return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); } template <> inline c10::complex ceil_impl (c10::complex z) { return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); } template inline c10::complex sgn_impl (c10::complex z) { if (z == c10::complex(0, 0)) { return c10::complex(0, 0); } else { return z / zabs(z); } } template inline TYPE floor_impl (TYPE z) { return std::floor(z); } template <> inline c10::complex floor_impl (c10::complex z) { return c10::complex(std::floor(z.real()), std::floor(z.imag())); } template <> inline c10::complex floor_impl (c10::complex z) { return c10::complex(std::floor(z.real()), std::floor(z.imag())); } template inline TYPE round_impl (TYPE z) { return std::nearbyint(z); } template <> inline c10::complex round_impl (c10::complex z) { return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); } template <> inline c10::complex round_impl (c10::complex z) { return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); } template inline TYPE trunc_impl (TYPE z) { return std::trunc(z); } template <> inline c10::complex trunc_impl (c10::complex z) { return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); } template <> inline c10::complex trunc_impl (c10::complex z) { return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); } template ::value, int> = 0> inline TYPE max_impl (TYPE a, TYPE b) { if (_isnan(a) || _isnan(b)) { return std::numeric_limits::quiet_NaN(); } else { return std::max(a, b); } } template ::value, int> = 0> inline TYPE max_impl (TYPE a, TYPE b) { if (_isnan(a)) { return a; } else if (_isnan(b)) { return b; } else { return std::abs(a) > std::abs(b) ? a : b; } } template ::value, int> = 0> inline TYPE min_impl (TYPE a, TYPE b) { if (_isnan(a) || _isnan(b)) { return std::numeric_limits::quiet_NaN(); } else { return std::min(a, b); } } template ::value, int> = 0> inline TYPE min_impl (TYPE a, TYPE b) { if (_isnan(a)) { return a; } else if (_isnan(b)) { return b; } else { return std::abs(a) < std::abs(b) ? a : b; } } } // end namespace }} //end at::native