// CPUBlas.h — BLAS-style CPU entry points in at::native::cpublas.
#pragma once
#include <cstdint>

#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TransposeType.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
  8. namespace at {
  9. namespace native {
  10. namespace cpublas {
  11. namespace internal {
  12. void normalize_last_dims(
  13. TransposeType transa, TransposeType transb,
  14. int64_t m, int64_t n, int64_t k,
  15. int64_t *lda, int64_t *ldb, int64_t *ldc);
  16. } // namespace internal
  17. using gemm_fn = void(*)(
  18. at::ScalarType type,
  19. TransposeType transa, TransposeType transb,
  20. int64_t m, int64_t n, int64_t k,
  21. const Scalar& alpha,
  22. const void *a, int64_t lda,
  23. const void *b, int64_t ldb,
  24. const Scalar& beta,
  25. void *c, int64_t ldc);
  26. DECLARE_DISPATCH(gemm_fn, gemm_stub);
  27. template <typename scalar_t>
  28. void gemm(
  29. TransposeType transa, TransposeType transb,
  30. int64_t m, int64_t n, int64_t k,
  31. at::opmath_type<scalar_t> alpha,
  32. const scalar_t *a, int64_t lda,
  33. const scalar_t *b, int64_t ldb,
  34. at::opmath_type<scalar_t> beta,
  35. scalar_t *c, int64_t ldc) {
  36. internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  37. gemm_stub(
  38. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  39. transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  40. }
  41. void gemm(
  42. TransposeType transa, TransposeType transb,
  43. int64_t m, int64_t n, int64_t k,
  44. double alpha,
  45. const double *a, int64_t lda,
  46. const double *b, int64_t ldb,
  47. double beta,
  48. double *c, int64_t ldc);
  49. void gemm(
  50. TransposeType transa, TransposeType transb,
  51. int64_t m, int64_t n, int64_t k,
  52. float alpha,
  53. const float *a, int64_t lda,
  54. const float *b, int64_t ldb,
  55. float beta,
  56. float *c, int64_t ldc);
  57. void gemm(
  58. TransposeType transa, TransposeType transb,
  59. int64_t m, int64_t n, int64_t k,
  60. float alpha,
  61. const at::BFloat16 *a, int64_t lda,
  62. const at::BFloat16 *b, int64_t ldb,
  63. float beta,
  64. at::BFloat16 *c, int64_t ldc);
  65. void gemm(
  66. TransposeType transa, TransposeType transb,
  67. int64_t m, int64_t n, int64_t k,
  68. c10::complex<double> alpha,
  69. const c10::complex<double> *a, int64_t lda,
  70. const c10::complex<double> *b, int64_t ldb,
  71. c10::complex<double> beta,
  72. c10::complex<double> *c, int64_t ldc);
  73. void gemm(
  74. TransposeType transa, TransposeType transb,
  75. int64_t m, int64_t n, int64_t k,
  76. c10::complex<float> alpha,
  77. const c10::complex<float> *a, int64_t lda,
  78. const c10::complex<float> *b, int64_t ldb,
  79. c10::complex<float> beta,
  80. c10::complex<float> *c, int64_t ldc);
  81. void gemm(
  82. TransposeType transa, TransposeType transb,
  83. int64_t m, int64_t n, int64_t k,
  84. int64_t alpha,
  85. const int64_t *a, int64_t lda,
  86. const int64_t *b, int64_t ldb,
  87. int64_t beta,
  88. int64_t *c, int64_t ldc);
  89. template <typename scalar_t>
  90. void gemm_batched(
  91. TransposeType transa, TransposeType transb,
  92. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  93. scalar_t alpha,
  94. const scalar_t * const *a, int64_t lda,
  95. const scalar_t * const *b, int64_t ldb,
  96. const scalar_t beta,
  97. scalar_t * const *c, int64_t ldc);
  98. template <typename scalar_t>
  99. void gemm_batched_with_stride(
  100. TransposeType transa, TransposeType transb,
  101. int64_t batch_size, int64_t m, int64_t n, int64_t k,
  102. scalar_t alpha,
  103. const scalar_t *a, int64_t lda, int64_t batch_stride_a,
  104. const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
  105. scalar_t beta,
  106. scalar_t *c, int64_t ldc, int64_t batch_stride_c);
  107. using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
  108. DECLARE_DISPATCH(axpy_fn, axpy_stub);
  109. template<typename scalar_t>
  110. void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
  111. if(n == 1)
  112. {
  113. incx = 1;
  114. incy = 1;
  115. }
  116. axpy_stub(
  117. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  118. n, a, x, incx, y, incy);
  119. }
  120. void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
  121. void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
  122. void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  123. void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  124. using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
  125. DECLARE_DISPATCH(copy_fn, copy_stub);
  126. template<typename scalar_t>
  127. void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  128. if(n == 1)
  129. {
  130. incx = 1;
  131. incy = 1;
  132. }
  133. copy_stub(
  134. kCPU, c10::CppTypeToScalarType<scalar_t>::value,
  135. n, x, incx, y, incy);
  136. }
  137. void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
  138. void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
  139. void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
  140. void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
  141. }}} // namespace at::native::cpublas