Pow.h 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. namespace c10 {
  4. class Scalar;
  5. }
  6. namespace at {
  7. struct TensorIterator;
  8. struct TensorIteratorBase;
  9. namespace native {
  // HOST_DEVICE expands to `__host__ __device__` when either __CUDACC__ or
  // __HIPCC__ is defined (presumably set by the CUDA / HIP compilers) and to
  // nothing otherwise, so the helpers below can carry a single qualifier that
  // is valid in both kinds of build.
  10. #if defined(__CUDACC__) || defined(__HIPCC__)
  11. #define HOST_DEVICE __host__ __device__
  12. #else
  13. #define HOST_DEVICE
  14. #endif
  15. // Integral power in PyTorch allows negative exponents and yields the truncated integral result,
  16. // e.g. since 2**-1 == 0.5, the truncated integral result is zero. With a negative exponent the
  17. // only non-zero results are 1**b == 1 and (-1)**b == +1 or -1 (for even or odd b respectively).
  18. template <class T,
  19. typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
  20. static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
  21. T result = 1;
  22. while (b) {
  23. if (b & 1) {
  24. result *= a;
  25. }
  26. b /= 2;
  27. a *= a;
  28. }
  29. return result;
  30. }
  31. template <class T,
  32. typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
  33. static inline HOST_DEVICE T powi(T a, T b) {
  34. return powi_impl(a, b);
  35. }
  36. template <class T,
  37. typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
  38. static inline HOST_DEVICE T powi(T a, T b) {
  39. if ( b < 0 ) {
  40. if ( a == 1 ) {
  41. return 1;
  42. } else if ( a == -1 ) {
  43. auto negative = (-b) % static_cast<T>(2);
  44. return negative ? -1 : 1;
  45. } else {
  46. return 0;
  47. }
  48. }
  49. return powi_impl(a, b);
  50. }
  // Function-pointer signatures for the pow kernels. Both receive the iterator
  // by reference; the tensor-scalar variant additionally receives the scalar
  // operand as a c10::Scalar.
  51. using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
  52. using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
  // Declarations of the two stubs through the DECLARE_DISPATCH macro, which
  // presumably comes from ATen/native/DispatchStub.h (included above); the
  // kernels bound to these stubs are not visible in this header.
  53. DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
  54. DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
  55. } // namespace native
  56. } // namespace at