vml.h
#pragma once
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/complex.h>
// This header implements various unary operations using a MKL VML style
// interface.
// It implements various functions with a simple interface.
// For example, it enables the user to call vsin(float* out, const float* in,
// size). This function takes a pointer to a contiguous output array of floats
// and a constant input array. It will then apply sin to each value in the
// input array and write the result into the output array. out and in may point
// to the same memory, i.e. this fully supports in-place operations. These
// functions also implement their own parallelization, so take precautions when
// calling these from threaded functions.
// When MKL is available it will call into MKL's VML library, similar to NumPy.
// If MKL is not available it will use SLEEF.
// This file might be compiled under AVX or AVX2 when called from e.g.
// UnaryOpsKernel.cpp
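// A minimal usage sketch (illustrative only; the buffers below are
// hypothetical and not part of this header):
//
//   std::vector<float> in(1024, 0.5f), out(1024);
//   at::vml::vsin(out.data(), in.data(), static_cast<int64_t>(in.size()));
//   // out[i] now holds sin(in[i]); passing the same pointer for out and in
//   // computes the result in place.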
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#if AT_MKL_ENABLED() && !defined(__APPLE__)
#include <mkl.h>
#endif
namespace at {
namespace vml {
inline namespace CPU_CAPABILITY {
using namespace vec;
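// vrsqrt computes 1/sqrt(x) elementwise. Like the other functions in this
// header it runs its own parallel_for (with a grain size of 2048 elements),
// so avoid calling it from an already-parallel region.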
template <typename scalar_t>
inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
    map(
        [](const Vectorized<scalar_t>& x) {
          return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
        },
        out + begin,
        in + begin,
        end - begin);
  });
}
// NB: We ignore numerical errors by convention and leave them to the user
#define IMPLEMENT_VML(op)                                               \
  template <typename scalar_t>                                          \
  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) {  \
    using vec_t = Vectorized<vec_scalar_t<scalar_t>>;                   \
    vec::map([](vec_t x) { return x.op(); }, out, in, size);            \
  }
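// For illustration, IMPLEMENT_VML(sin) expands to (roughly):
//
//   template <typename scalar_t>
//   inline void vsin(scalar_t* out, const scalar_t* in, int64_t size) {
//     using vec_t = Vectorized<vec_scalar_t<scalar_t>>;
//     vec::map([](vec_t x) { return x.sin(); }, out, in, size);
//   }
//
// i.e. each generated function forwards elementwise to the corresponding
// Vectorized method.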
IMPLEMENT_VML(abs)
IMPLEMENT_VML(acos)
IMPLEMENT_VML(asin)
IMPLEMENT_VML(atan)
IMPLEMENT_VML(ceil)
IMPLEMENT_VML(cos)
// IMPLEMENT_VML(cosh)
IMPLEMENT_VML(erf)
IMPLEMENT_VML(erfc)
IMPLEMENT_VML(erfinv)
IMPLEMENT_VML(exp)
IMPLEMENT_VML(expm1)
IMPLEMENT_VML(floor)
IMPLEMENT_VML(i0)
IMPLEMENT_VML(i0e)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML(log)
IMPLEMENT_VML(log10)
IMPLEMENT_VML(log1p)
IMPLEMENT_VML(log2)
IMPLEMENT_VML(neg)
IMPLEMENT_VML(sin)
// IMPLEMENT_VML(sinh)
IMPLEMENT_VML(sqrt)
IMPLEMENT_VML(round)
IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML(tan)
IMPLEMENT_VML(tanh)
IMPLEMENT_VML(trunc)
IMPLEMENT_VML(lgamma)
#if AT_MKL_ENABLED() && !defined(__APPLE__)
// NB: LP64 MKL is the most commonly used and thus we assume it here. That means
// we need to expect MKL_INT to be of type int, which implies int32_t in most
// cases.
static_assert(
    std::is_same<MKL_INT, int32_t>::value,
    "MKL_INT is assumed to be int32_t");
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype)                  \
  template <>                                                             \
  inline void v##op(type * out, const type * in, int64_t size) {          \
    int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max();            \
    if (size <= static_cast<int64_t>(max_mkl_ind)) {                      \
      vm##mkltype##mklop(                                                 \
          size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);   \
    } else {                                                              \
      MKL_INT ind = 0;                                                    \
      int64_t chunks = size / max_mkl_ind;                                \
      int64_t rest = size % max_mkl_ind;                                  \
      for (; ind < chunks; ind++) {                                       \
        vm##mkltype##mklop(                                               \
            max_mkl_ind,                                                  \
            in + ind * max_mkl_ind,                                       \
            out + ind * max_mkl_ind,                                      \
            VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);                \
      }                                                                   \
      vm##mkltype##mklop(                                                 \
          rest,                                                           \
          in + ind * max_mkl_ind,                                         \
          out + ind * max_mkl_ind,                                        \
          VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);                  \
    }                                                                     \
  }
#define IMPLEMENT_VML_MKL(op, mklop)            \
  IMPLEMENT_VML_MKL_STUB(op, mklop, float, s)   \
  IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)
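// For illustration, IMPLEMENT_VML_MKL(exp, Exp) specializes vexp<float> and
// vexp<double> to call MKL's vmsExp / vmdExp respectively. Inputs with more
// than INT32_MAX elements are processed in MKL_INT-sized chunks, since LP64
// MKL takes 32-bit lengths.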
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
IMPLEMENT_VML_MKL(acos, Acos)
IMPLEMENT_VML_MKL(asin, Asin)
IMPLEMENT_VML_MKL(atan, Atan)
IMPLEMENT_VML_MKL(cos, Cos)
// IMPLEMENT_VML_MKL(cosh, Cosh)
IMPLEMENT_VML_MKL(erf, Erf)
IMPLEMENT_VML_MKL(erfc, Erfc)
IMPLEMENT_VML_MKL(erfinv, ErfInv)
IMPLEMENT_VML_MKL(exp, Exp)
// IMPLEMENT_VML_MKL(expm1, Expm1)
IMPLEMENT_VML_MKL(log, Ln)
IMPLEMENT_VML_MKL(log10, Log10)
IMPLEMENT_VML_MKL(sin, Sin)
// IMPLEMENT_VML_MKL(sinh, Sinh)
IMPLEMENT_VML_MKL(sqrt, Sqrt)
IMPLEMENT_VML_MKL(tan, Tan)
IMPLEMENT_VML_MKL(tanh, Tanh)
IMPLEMENT_VML_MKL(trunc, Trunc)

// Not vectorized in MKL version tested
// IMPLEMENT_VML_MKL(abs, Abs)
// IMPLEMENT_VML_MKL(log1p, Log1p)

#if INTEL_MKL_VERSION >= 20180406
IMPLEMENT_VML_MKL(log2, Log2)
#endif

#endif
} // namespace
} // namespace vml
} // namespace at