small_blas_generic.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: yangfan34@lenovo.com (Lenovo Research Device+ Lab - Shanghai)
  30. //
  31. // Optimization for simple blas functions used in the Schur Eliminator.
  32. // These are fairly basic implementations which already yield a significant
  33. // speedup in the eliminator performance.
  34. #ifndef CERES_INTERNAL_SMALL_BLAS_GENERIC_H_
  35. #define CERES_INTERNAL_SMALL_BLAS_GENERIC_H_
  36. namespace ceres::internal {
  // Building blocks shared by the 1x4 / 4x1 kernels in this file.
  //
  // CERES_GEMM_OPT_NAIVE_HEADER expands to a run of declarations, so it is
  // meant to be placed at the top of a function body that has `col_a`, `a`
  // and `b` in scope. It introduces:
  //   span   - the unroll factor, 4;
  //   col_r  - col_a & (span - 1), i.e. the columns left over after
  //            unrolling (col_a modulo 4 for non-negative col_a);
  //   col_m  - col_a - col_r, the column count covered by the unrolled loop;
  //   pa, pb - read cursors initialised to a and b;
  //   cvec4  - four accumulators, all starting at 0.0.
  #define CERES_GEMM_OPT_NAIVE_HEADER       \
    const int span = 4;                     \
    int col_r = col_a & (span - 1);         \
    int col_m = col_a - col_r;              \
    const double* pa = a;                   \
    const double* pb = b;                   \
    double cvec4[4] = {0.0, 0.0, 0.0, 0.0};

  // CERES_GEMM_OPT_STORE_MAT1X4 combines the four accumulators in cvec4 with
  // c[0..3]. The operation is chosen by the sign of `kOperation`, which must
  // be in scope where the macro is expanded:
  //   kOperation == 0 : c[i]  = cvec4[i]
  //   kOperation >  0 : c[i] += cvec4[i]
  //   kOperation <  0 : c[i] -= cvec4[i]
  // Afterwards the pointer variable `c` is advanced by four elements.
  #define CERES_GEMM_OPT_STORE_MAT1X4 \
    if (kOperation == 0) {            \
      c[0] = cvec4[0];                \
      c[1] = cvec4[1];                \
      c[2] = cvec4[2];                \
      c[3] = cvec4[3];                \
    } else if (kOperation > 0) {      \
      c[0] += cvec4[0];               \
      c[1] += cvec4[1];               \
      c[2] += cvec4[2];               \
      c[3] += cvec4[3];               \
    } else {                          \
      c[0] -= cvec4[0];               \
      c[1] -= cvec4[1];               \
      c[2] -= cvec4[2];               \
      c[3] -= cvec4[3];               \
    }                                 \
    c += 4;
  // Matrix-Matrix multiplication: computes one 1x4 strip of C per call.
  //
  //   c op a * B
  //
  // where op is +=, -= or =, selected by kOperation (> 0, < 0, == 0).
  //
  //   Matrix C              Matrix A                  Matrix B
  //
  //   C0, C1, C2, C3   op   A0, A1, A2, A3, ...   *   B0, B1, B2, B3
  //                                                   B4, B5, B6, B7
  //                                                   B8, B9, Ba, Bb
  //                                                   Bc, Bd, Be, Bf
  //                                                   . , . , . , .
  //                                                   . , . , . , .
  //
  // a[0 .. col_a) is read contiguously. Row k of B starts at
  // b + k * col_stride_b and four values are read from it. The loop over k
  // is unrolled four times; the remaining col_a % 4 iterations run one at a
  // time.
  //
  // NOTE: col_a is the number of columns of A.
  static inline void MMM_mat1x4(const int col_a,
                                const double* a,
                                const double* b,
                                const int col_stride_b,
                                double* c,
                                const int kOperation) {
    // Columns [0, unrolled_cols) are handled by the 4x unrolled loop.
    const int unrolled_cols = col_a - (col_a & 3);

    double sum0 = 0.0;
    double sum1 = 0.0;
    double sum2 = 0.0;
    double sum3 = 0.0;

    int b_offset = 0;  // Offset of row k of B, in doubles.
    int k = 0;

  // Adds a[k] * B(k, 0..3) to the running sums and moves on to k + 1.
  #define CERES_MMM_1X4_ACCUMULATE          \
    do {                                    \
      const double a_k = a[k];              \
      const double* b_row = b + b_offset;   \
      sum0 += a_k * b_row[0];               \
      sum1 += a_k * b_row[1];               \
      sum2 += a_k * b_row[2];               \
      sum3 += a_k * b_row[3];               \
      b_offset += col_stride_b;             \
      ++k;                                  \
    } while (0)

    while (k < unrolled_cols) {
      CERES_MMM_1X4_ACCUMULATE;
      CERES_MMM_1X4_ACCUMULATE;
      CERES_MMM_1X4_ACCUMULATE;
      CERES_MMM_1X4_ACCUMULATE;
    }
    while (k < col_a) {
      CERES_MMM_1X4_ACCUMULATE;
    }

  #undef CERES_MMM_1X4_ACCUMULATE

    // Fold the accumulated strip into c according to kOperation.
    if (kOperation == 0) {
      c[0] = sum0;
      c[1] = sum1;
      c[2] = sum2;
      c[3] = sum3;
    } else if (kOperation > 0) {
      c[0] += sum0;
      c[1] += sum1;
      c[2] += sum2;
      c[3] += sum3;
    } else {
      c[0] -= sum0;
      c[1] -= sum1;
      c[2] -= sum2;
      c[3] -= sum3;
    }
  }
  // Matrix Transpose-Matrix multiplication: computes one 1x4 strip of C per
  // call.
  //
  //   c op a' * B
  //
  // where op is +=, -= or =, selected by kOperation (> 0, < 0, == 0).
  //
  //   Matrix A
  //
  //   A0
  //   A1
  //   A2
  //   A3
  //   .
  //   .
  //
  //   Matrix C              Matrix A'                 Matrix B
  //
  //   C0, C1, C2, C3   op   A0, A1, A2, A3, ...   *   B0, B1, B2, B3
  //                                                   B4, B5, B6, B7
  //                                                   B8, B9, Ba, Bb
  //                                                   Bc, Bd, Be, Bf
  //                                                   . , . , . , .
  //                                                   . , . , . , .
  //
  // Element k of the column of A is read from a[k * col_stride_a]. Row k of
  // B starts at b + k * col_stride_b and four values are read from it. The
  // loop over k is unrolled four times; the remaining col_a % 4 iterations
  // run one at a time.
  //
  // NOTE: col_a is the number of columns of A'.
  static inline void MTM_mat1x4(const int col_a,
                                const double* a,
                                const int col_stride_a,
                                const double* b,
                                const int col_stride_b,
                                double* c,
                                const int kOperation) {
    // Columns [0, unrolled_cols) are handled by the 4x unrolled loop.
    const int unrolled_cols = col_a - (col_a & 3);

    double sum0 = 0.0;
    double sum1 = 0.0;
    double sum2 = 0.0;
    double sum3 = 0.0;

    int a_offset = 0;  // Offset of element k of A's column, in doubles.
    int b_offset = 0;  // Offset of row k of B, in doubles.
    int k = 0;

  // Adds A'(0, k) * B(k, 0..3) to the running sums and moves on to k + 1.
  #define CERES_MTM_1X4_ACCUMULATE          \
    do {                                    \
      const double a_k = a[a_offset];       \
      const double* b_row = b + b_offset;   \
      sum0 += a_k * b_row[0];               \
      sum1 += a_k * b_row[1];               \
      sum2 += a_k * b_row[2];               \
      sum3 += a_k * b_row[3];               \
      a_offset += col_stride_a;             \
      b_offset += col_stride_b;             \
      ++k;                                  \
    } while (0)

    while (k < unrolled_cols) {
      CERES_MTM_1X4_ACCUMULATE;
      CERES_MTM_1X4_ACCUMULATE;
      CERES_MTM_1X4_ACCUMULATE;
      CERES_MTM_1X4_ACCUMULATE;
    }
    while (k < col_a) {
      CERES_MTM_1X4_ACCUMULATE;
    }

  #undef CERES_MTM_1X4_ACCUMULATE

    // Fold the accumulated strip into c according to kOperation.
    if (kOperation == 0) {
      c[0] = sum0;
      c[1] = sum1;
      c[2] = sum2;
      c[3] = sum3;
    } else if (kOperation > 0) {
      c[0] += sum0;
      c[1] += sum1;
      c[2] += sum2;
      c[3] += sum3;
    } else {
      c[0] -= sum0;
      c[1] -= sum1;
      c[2] -= sum2;
      c[3] -= sum3;
    }
  }
  // Matrix-Vector multiplication: computes four entries of c per call.
  //
  //   c op A * b
  //
  // where op is +=, -= or =, selected by kOperation (> 0, < 0, == 0).
  //
  //   Vector c        Matrix A                  Vector b
  //
  //   C0         op   A0, A1, A2, A3, ...   *   B0
  //   C1              A4, A5, A6, A7, ...       B1
  //   C2              A8, A9, Aa, Ab, ...       B2
  //   C3              Ac, Ad, Ae, Af, ...       B3
  //                                             .
  //                                             .
  //
  // Row i (i = 0..3) of A starts at a + i * col_stride_a and col_a values
  // are read from it. b[0 .. col_a) is read contiguously. The loop over the
  // columns is unrolled four times; the remaining col_a % 4 iterations run
  // one at a time.
  //
  // NOTE: col_a is the number of columns of A.
  static inline void MVM_mat4x1(const int col_a,
                                const double* a,
                                const int col_stride_a,
                                const double* b,
                                double* c,
                                const int kOperation) {
    // Columns [0, unrolled_cols) are handled by the 4x unrolled loop.
    const int unrolled_cols = col_a - (col_a & 3);

    double sum0 = 0.0;
    double sum1 = 0.0;
    double sum2 = 0.0;
    double sum3 = 0.0;

    int k = 0;

  // Adds A(0..3, k) * b[k] to the running sums and moves on to k + 1.
  #define CERES_MVM_4X1_ACCUMULATE                 \
    do {                                           \
      const double b_k = b[k];                     \
      const double* a_col = a + k;                 \
      sum0 += a_col[0] * b_k;                      \
      sum1 += a_col[col_stride_a] * b_k;           \
      sum2 += a_col[col_stride_a * 2] * b_k;       \
      sum3 += a_col[col_stride_a * 3] * b_k;       \
      ++k;                                         \
    } while (0)

    while (k < unrolled_cols) {
      CERES_MVM_4X1_ACCUMULATE;
      CERES_MVM_4X1_ACCUMULATE;
      CERES_MVM_4X1_ACCUMULATE;
      CERES_MVM_4X1_ACCUMULATE;
    }
    while (k < col_a) {
      CERES_MVM_4X1_ACCUMULATE;
    }

  #undef CERES_MVM_4X1_ACCUMULATE

    // Fold the four accumulated entries into c according to kOperation.
    if (kOperation == 0) {
      c[0] = sum0;
      c[1] = sum1;
      c[2] = sum2;
      c[3] = sum3;
    } else if (kOperation > 0) {
      c[0] += sum0;
      c[1] += sum1;
      c[2] += sum2;
      c[3] += sum3;
    } else {
      c[0] -= sum0;
      c[1] -= sum1;
      c[2] -= sum2;
      c[3] -= sum3;
    }
  }
  // Matrix Transpose-Vector multiplication: computes four entries of c per
  // call.
  //
  //   c op A' * b
  //
  // where op is +=, -= or =, selected by kOperation (> 0, < 0, == 0).
  //
  //   Matrix A
  //
  //   A0, A4, A8, Ac
  //   A1, A5, A9, Ad
  //   A2, A6, Aa, Ae
  //   A3, A7, Ab, Af
  //   . , . , . , .
  //   . , . , . , .
  //
  //   Vector c        Matrix A'                 Vector b
  //
  //   C0         op   A0, A1, A2, A3, ...   *   B0
  //   C1              A4, A5, A6, A7, ...       B1
  //   C2              A8, A9, Aa, Ab, ...       B2
  //   C3              Ac, Ad, Ae, Af, ...       B3
  //                                             .
  //                                             .
  //
  // Row k of A starts at a + k * col_stride_a and four values are read from
  // it. b[0 .. col_a) is read contiguously. The loop over k is unrolled four
  // times; the remaining col_a % 4 iterations run one at a time.
  //
  // NOTE: col_a is the number of columns of A'.
  static inline void MTV_mat4x1(const int col_a,
                                const double* a,
                                const int col_stride_a,
                                const double* b,
                                double* c,
                                const int kOperation) {
    // Columns [0, unrolled_cols) are handled by the 4x unrolled loop.
    const int unrolled_cols = col_a - (col_a & 3);

    double sum0 = 0.0;
    double sum1 = 0.0;
    double sum2 = 0.0;
    double sum3 = 0.0;

    const double* a_row = a;  // Start of row k of A.
    int k = 0;

  // Adds A(k, 0..3) * b[k] to the running sums and moves on to k + 1.
  #define CERES_MTV_4X1_ACCUMULATE   \
    do {                             \
      const double b_k = b[k];       \
      sum0 += a_row[0] * b_k;        \
      sum1 += a_row[1] * b_k;        \
      sum2 += a_row[2] * b_k;        \
      sum3 += a_row[3] * b_k;        \
      a_row += col_stride_a;         \
      ++k;                           \
    } while (0)

    while (k < unrolled_cols) {
      CERES_MTV_4X1_ACCUMULATE;
      CERES_MTV_4X1_ACCUMULATE;
      CERES_MTV_4X1_ACCUMULATE;
      CERES_MTV_4X1_ACCUMULATE;
    }
    while (k < col_a) {
      CERES_MTV_4X1_ACCUMULATE;
    }

  #undef CERES_MTV_4X1_ACCUMULATE

    // Fold the four accumulated entries into c according to kOperation.
    if (kOperation == 0) {
      c[0] = sum0;
      c[1] = sum1;
      c[2] = sum2;
      c[3] = sum3;
    } else if (kOperation > 0) {
      c[0] += sum0;
      c[1] += sum1;
      c[2] += sum2;
      c[3] += sum3;
    } else {
      c[0] -= sum0;
      c[1] -= sum1;
      c[2] -= sum2;
      c[3] -= sum3;
    }
  }
  281. #undef CERES_GEMM_OPT_NAIVE_HEADER
  282. #undef CERES_GEMM_OPT_STORE_MAT1X4
  283. } // namespace ceres::internal
  284. #endif // CERES_INTERNAL_SMALL_BLAS_GENERIC_H_