BatchLinearAlgebra.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
#pragma once
#include <c10/util/Optional.h>
#include <c10/util/string_view.h>
#include <ATen/Config.h>
#include <ATen/native/DispatchStub.h>

#include <cstdint>
#include <string>
  5. // Forward declare TI
  6. namespace at {
  7. class Tensor;
  8. struct TensorIterator;
  9. namespace native {
  10. enum class TransposeType;
  11. }
  12. }
  13. namespace at { namespace native {
  14. enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
  15. #if AT_BUILD_WITH_LAPACK()
  16. // Define per-batch functions to be used in the implementation of batched
  17. // linear algebra operations
  18. template <class scalar_t>
  19. void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
  20. template <class scalar_t>
  21. void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
  22. template <class scalar_t, class value_t=scalar_t>
  23. void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
  24. template <class scalar_t>
  25. void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
  26. template <class scalar_t>
  27. void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
  28. template <class scalar_t>
  29. void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
  30. template <class scalar_t, class value_t = scalar_t>
  31. void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
  32. template <class scalar_t>
  33. void lapackGels(char trans, int m, int n, int nrhs,
  34. scalar_t *a, int lda, scalar_t *b, int ldb,
  35. scalar_t *work, int lwork, int *info);
  36. template <class scalar_t, class value_t = scalar_t>
  37. void lapackGelsd(int m, int n, int nrhs,
  38. scalar_t *a, int lda, scalar_t *b, int ldb,
  39. value_t *s, value_t rcond, int *rank,
  40. scalar_t* work, int lwork,
  41. value_t *rwork, int* iwork, int *info);
  42. template <class scalar_t, class value_t = scalar_t>
  43. void lapackGelsy(int m, int n, int nrhs,
  44. scalar_t *a, int lda, scalar_t *b, int ldb,
  45. int *jpvt, value_t rcond, int *rank,
  46. scalar_t *work, int lwork, value_t* rwork, int *info);
  47. template <class scalar_t, class value_t = scalar_t>
  48. void lapackGelss(int m, int n, int nrhs,
  49. scalar_t *a, int lda, scalar_t *b, int ldb,
  50. value_t *s, value_t rcond, int *rank,
  51. scalar_t *work, int lwork,
  52. value_t *rwork, int *info);
  53. template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
  54. struct lapackLstsq_impl;
  55. template <class scalar_t, class value_t>
  56. struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
  57. static void call(
  58. char trans, int m, int n, int nrhs,
  59. scalar_t *a, int lda, scalar_t *b, int ldb,
  60. scalar_t *work, int lwork, int *info, // Gels flavor
  61. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  62. value_t *s, // Gelss flavor
  63. int *iwork // Gelsd flavor
  64. ) {
  65. lapackGels<scalar_t>(
  66. trans, m, n, nrhs,
  67. a, lda, b, ldb,
  68. work, lwork, info);
  69. }
  70. };
  71. template <class scalar_t, class value_t>
  72. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
  73. static void call(
  74. char trans, int m, int n, int nrhs,
  75. scalar_t *a, int lda, scalar_t *b, int ldb,
  76. scalar_t *work, int lwork, int *info, // Gels flavor
  77. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  78. value_t *s, // Gelss flavor
  79. int *iwork // Gelsd flavor
  80. ) {
  81. lapackGelsy<scalar_t, value_t>(
  82. m, n, nrhs,
  83. a, lda, b, ldb,
  84. jpvt, rcond, rank,
  85. work, lwork, rwork, info);
  86. }
  87. };
  88. template <class scalar_t, class value_t>
  89. struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
  90. static void call(
  91. char trans, int m, int n, int nrhs,
  92. scalar_t *a, int lda, scalar_t *b, int ldb,
  93. scalar_t *work, int lwork, int *info, // Gels flavor
  94. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  95. value_t *s, // Gelss flavor
  96. int *iwork // Gelsd flavor
  97. ) {
  98. lapackGelsd<scalar_t, value_t>(
  99. m, n, nrhs,
  100. a, lda, b, ldb,
  101. s, rcond, rank,
  102. work, lwork,
  103. rwork, iwork, info);
  104. }
  105. };
  106. template <class scalar_t, class value_t>
  107. struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
  108. static void call(
  109. char trans, int m, int n, int nrhs,
  110. scalar_t *a, int lda, scalar_t *b, int ldb,
  111. scalar_t *work, int lwork, int *info, // Gels flavor
  112. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  113. value_t *s, // Gelss flavor
  114. int *iwork // Gelsd flavor
  115. ) {
  116. lapackGelss<scalar_t, value_t>(
  117. m, n, nrhs,
  118. a, lda, b, ldb,
  119. s, rcond, rank,
  120. work, lwork,
  121. rwork, info);
  122. }
  123. };
  124. template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
  125. void lapackLstsq(
  126. char trans, int m, int n, int nrhs,
  127. scalar_t *a, int lda, scalar_t *b, int ldb,
  128. scalar_t *work, int lwork, int *info, // Gels flavor
  129. int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
  130. value_t *s, // Gelss flavor
  131. int *iwork // Gelsd flavor
  132. ) {
  133. lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
  134. trans, m, n, nrhs,
  135. a, lda, b, ldb,
  136. work, lwork, info,
  137. jpvt, rcond, rank, rwork,
  138. s,
  139. iwork);
  140. }
  141. template <class scalar_t>
  142. void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
  143. template <class scalar_t>
  144. void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
  145. template <class scalar_t>
  146. void lapackLdlHermitian(
  147. char uplo,
  148. int n,
  149. scalar_t* a,
  150. int lda,
  151. int* ipiv,
  152. scalar_t* work,
  153. int lwork,
  154. int* info);
  155. template <class scalar_t>
  156. void lapackLdlSymmetric(
  157. char uplo,
  158. int n,
  159. scalar_t* a,
  160. int lda,
  161. int* ipiv,
  162. scalar_t* work,
  163. int lwork,
  164. int* info);
  165. template <class scalar_t>
  166. void lapackLdlSolveHermitian(
  167. char uplo,
  168. int n,
  169. int nrhs,
  170. scalar_t* a,
  171. int lda,
  172. int* ipiv,
  173. scalar_t* b,
  174. int ldb,
  175. int* info);
  176. template <class scalar_t>
  177. void lapackLdlSolveSymmetric(
  178. char uplo,
  179. int n,
  180. int nrhs,
  181. scalar_t* a,
  182. int lda,
  183. int* ipiv,
  184. scalar_t* b,
  185. int ldb,
  186. int* info);
  187. template<class scalar_t, class value_t=scalar_t>
  188. void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
  189. #endif
  190. #if AT_BUILD_WITH_BLAS()
  191. template <class scalar_t>
  192. void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
  193. #endif
  194. using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
  195. DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
  196. using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
  197. DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
  198. using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
  199. DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
  200. using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
  201. DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
  202. using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
  203. DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
  204. using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
  205. DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
  206. using linalg_eigh_fn = void (*)(
  207. const Tensor& /*eigenvalues*/,
  208. const Tensor& /*eigenvectors*/,
  209. const Tensor& /*infos*/,
  210. bool /*upper*/,
  211. bool /*compute_eigenvectors*/);
  212. DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
  213. using lstsq_fn = void (*)(
  214. const Tensor& /*a*/,
  215. Tensor& /*b*/,
  216. Tensor& /*rank*/,
  217. Tensor& /*singular_values*/,
  218. Tensor& /*infos*/,
  219. double /*rcond*/,
  220. std::string /*driver_name*/);
  221. DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
  222. using triangular_solve_fn = void (*)(
  223. const Tensor& /*A*/,
  224. const Tensor& /*B*/,
  225. bool /*left*/,
  226. bool /*upper*/,
  227. TransposeType /*transpose*/,
  228. bool /*unitriangular*/);
  229. DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
  230. using lu_factor_fn = void (*)(
  231. const Tensor& /*input*/,
  232. const Tensor& /*pivots*/,
  233. const Tensor& /*infos*/,
  234. bool /*compute_pivots*/);
  235. DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
  236. using unpack_pivots_fn = void(*)(
  237. TensorIterator& iter,
  238. const int64_t dim_size,
  239. const int64_t max_pivot);
  240. DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
  241. using lu_solve_fn = void (*)(
  242. const Tensor& /*LU*/,
  243. const Tensor& /*pivots*/,
  244. const Tensor& /*B*/,
  245. TransposeType /*trans*/);
  246. DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
  247. using ldl_factor_fn = void (*)(
  248. const Tensor& /*LD*/,
  249. const Tensor& /*pivots*/,
  250. const Tensor& /*info*/,
  251. bool /*upper*/,
  252. bool /*hermitian*/);
  253. DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
  254. using svd_fn = void (*)(
  255. const Tensor& /*A*/,
  256. const bool /*full_matrices*/,
  257. const bool /*compute_uv*/,
  258. const c10::optional<c10::string_view>& /*driver*/,
  259. const Tensor& /*U*/,
  260. const Tensor& /*S*/,
  261. const Tensor& /*Vh*/,
  262. const Tensor& /*info*/);
  263. DECLARE_DISPATCH(svd_fn, svd_stub);
  264. using ldl_solve_fn = void (*)(
  265. const Tensor& /*LD*/,
  266. const Tensor& /*pivots*/,
  267. const Tensor& /*result*/,
  268. bool /*upper*/,
  269. bool /*hermitian*/);
  270. DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
  271. }} // namespace at::native