kronecker_product.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
  5. // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
  6. // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
  7. //
  8. // This Source Code Form is subject to the terms of the Mozilla
  9. // Public License v. 2.0. If a copy of the MPL was not distributed
  10. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  11. #ifdef EIGEN_TEST_PART_1
  12. #include "sparse.h"
  13. #include <Eigen/SparseExtra>
  14. #include <Eigen/KroneckerProduct>
  // Shape-only check: verifies that `product` has exactly the expected number
  // of rows and columns. Its coefficients are not inspected.
  template<typename MatrixType>
  void check_dimension(const MatrixType& product, const int expectedRows, const int expectedCols)
  {
    VERIFY_IS_EQUAL(product.rows(), expectedRows);
    VERIFY_IS_EQUAL(product.cols(), expectedCols);
  }
  // Verifies that `ab` is the 6x6 Kronecker product of the fixed 2x3 and 3x2
  // operands used by the kronecker_product test. It checks the shape, the
  // reported non-zero count and every coefficient (row by row) against a
  // reference table.
  template<typename MatrixType>
  void check_kronecker_product(const MatrixType& ab)
  {
    // Expected coefficients, reference[row][col].
    static const double reference[6][6] = {
      { -0.4017367630386106,   0.1056863433932735,  -0.7255206194554212,   0.1908653336744706,   0.350864567234111,    -0.0923032108308013  },
      {  0.415417514804677,   -0.2369227701722048,   0.7502275131458511,  -0.4278731019742696,  -0.3628129162264507,   0.2069210808481275  },
      {  0.05465890160863986, -0.2634092511419858,   0.09871180285793758, -0.4757066334017702,  -0.04773740823058334,  0.2300535609645254  },
      { -0.8172945853260133,   0.2150086428359221,   0.5825113847292743,  -0.1532433770097174,  -0.329383387282399,     0.08665207912033064 },
      {  0.8451267514863225,  -0.481996458918977,   -0.6023482390791535,   0.3435339347164565,   0.3406002157428891,   -0.1942526344200915  },
      {  0.1111982482925399,  -0.5358806424754169,  -0.07925446559335647,  0.3819388757769038,   0.04481475387219876, -0.2159688616158057  }
    };

    VERIFY_IS_EQUAL(ab.rows(), 6);
    VERIFY_IS_EQUAL(ab.cols(), 6);
    VERIFY_IS_EQUAL(ab.nonZeros(), 36);
    for (int r = 0; r < 6; ++r)
      for (int c = 0; c < 6; ++c)
        VERIFY_IS_APPROX(ab.coeff(r, c), reference[r][c]);
  }
  // Verifies that `ab` is the 12x10 Kronecker product of the sparse 4x5 (3
  // non-zeros) and 3x2 (2 non-zeros) operands built in the sparse-pattern
  // section of the kronecker_product test. It checks the shape, that exactly
  // 3*2 entries are stored, and the value of each of them.
  template<typename MatrixType>
  void check_sparse_kronecker_product(const MatrixType& ab)
  {
    // Expected stored entries as (row, col, value) triplets.
    struct Entry { int row; int col; double value; };
    const Entry expected[] = {
      { 3, 0, -0.04 },
      { 5, 1,  0.05 },
      { 0, 6, -0.08 },
      { 2, 7,  0.10 },
      { 6, 8,  0.12 },
      { 8, 9, -0.15 }
    };
    const int entryCount = static_cast<int>(sizeof(expected) / sizeof(expected[0]));

    VERIFY_IS_EQUAL(ab.rows(), 12);
    VERIFY_IS_EQUAL(ab.cols(), 10);
    VERIFY_IS_EQUAL(ab.nonZeros(), 3*2);
    for (int k = 0; k < entryCount; ++k)
      VERIFY_IS_APPROX(ab.coeff(expected[k].row, expected[k].col), expected[k].value);
  }
  77. EIGEN_DECLARE_TEST(kronecker_product)
  78. {
  79. // DM = dense matrix; SM = sparse matrix
  80. Matrix<double, 2, 3> DM_a;
  81. SparseMatrix<double> SM_a(2,3);
  82. SM_a.insert(0,0) = DM_a.coeffRef(0,0) = -0.4461540300782201;
  83. SM_a.insert(0,1) = DM_a.coeffRef(0,1) = -0.8057364375283049;
  84. SM_a.insert(0,2) = DM_a.coeffRef(0,2) = 0.3896572459516341;
  85. SM_a.insert(1,0) = DM_a.coeffRef(1,0) = -0.9076572187376921;
  86. SM_a.insert(1,1) = DM_a.coeffRef(1,1) = 0.6469156566545853;
  87. SM_a.insert(1,2) = DM_a.coeffRef(1,2) = -0.3658010398782789;
  88. MatrixXd DM_b(3,2);
  89. SparseMatrix<double> SM_b(3,2);
  90. SM_b.insert(0,0) = DM_b.coeffRef(0,0) = 0.9004440976767099;
  91. SM_b.insert(0,1) = DM_b.coeffRef(0,1) = -0.2368830858139832;
  92. SM_b.insert(1,0) = DM_b.coeffRef(1,0) = -0.9311078389941825;
  93. SM_b.insert(1,1) = DM_b.coeffRef(1,1) = 0.5310335762980047;
  94. SM_b.insert(2,0) = DM_b.coeffRef(2,0) = -0.1225112806872035;
  95. SM_b.insert(2,1) = DM_b.coeffRef(2,1) = 0.5903998022741264;
  96. SparseMatrix<double,RowMajor> SM_row_a(SM_a), SM_row_b(SM_b);
  97. // test DM_fixedSize = kroneckerProduct(DM_block,DM)
  98. Matrix<double, 6, 6> DM_fix_ab = kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b);
  99. CALL_SUBTEST(check_kronecker_product(DM_fix_ab));
  100. CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a.topLeftCorner<2,3>(),DM_b)));
  101. for(int i=0;i<DM_fix_ab.rows();++i)
  102. for(int j=0;j<DM_fix_ab.cols();++j)
  103. VERIFY_IS_APPROX(kroneckerProduct(DM_a,DM_b).coeff(i,j), DM_fix_ab(i,j));
  104. // test DM_block = kroneckerProduct(DM,DM)
  105. MatrixXd DM_block_ab(10,15);
  106. DM_block_ab.block<6,6>(2,5) = kroneckerProduct(DM_a,DM_b);
  107. CALL_SUBTEST(check_kronecker_product(DM_block_ab.block<6,6>(2,5)));
  108. // test DM = kroneckerProduct(DM,DM)
  109. MatrixXd DM_ab = kroneckerProduct(DM_a,DM_b);
  110. CALL_SUBTEST(check_kronecker_product(DM_ab));
  111. CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,DM_b)));
  112. // test SM = kroneckerProduct(SM,DM)
  113. SparseMatrix<double> SM_ab = kroneckerProduct(SM_a,DM_b);
  114. CALL_SUBTEST(check_kronecker_product(SM_ab));
  115. SparseMatrix<double,RowMajor> SM_ab2 = kroneckerProduct(SM_a,DM_b);
  116. CALL_SUBTEST(check_kronecker_product(SM_ab2));
  117. CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,DM_b)));
  118. // test SM = kroneckerProduct(DM,SM)
  119. SM_ab.setZero();
  120. SM_ab.insert(0,0)=37.0;
  121. SM_ab = kroneckerProduct(DM_a,SM_b);
  122. CALL_SUBTEST(check_kronecker_product(SM_ab));
  123. SM_ab2.setZero();
  124. SM_ab2.insert(0,0)=37.0;
  125. SM_ab2 = kroneckerProduct(DM_a,SM_b);
  126. CALL_SUBTEST(check_kronecker_product(SM_ab2));
  127. CALL_SUBTEST(check_kronecker_product(kroneckerProduct(DM_a,SM_b)));
  128. // test SM = kroneckerProduct(SM,SM)
  129. SM_ab.resize(2,33);
  130. SM_ab.insert(0,0)=37.0;
  131. SM_ab = kroneckerProduct(SM_a,SM_b);
  132. CALL_SUBTEST(check_kronecker_product(SM_ab));
  133. SM_ab2.resize(5,11);
  134. SM_ab2.insert(0,0)=37.0;
  135. SM_ab2 = kroneckerProduct(SM_a,SM_b);
  136. CALL_SUBTEST(check_kronecker_product(SM_ab2));
  137. CALL_SUBTEST(check_kronecker_product(kroneckerProduct(SM_a,SM_b)));
  138. // test SM = kroneckerProduct(SM,SM) with sparse pattern
  139. SM_a.resize(4,5);
  140. SM_b.resize(3,2);
  141. SM_a.resizeNonZeros(0);
  142. SM_b.resizeNonZeros(0);
  143. SM_a.insert(1,0) = -0.1;
  144. SM_a.insert(0,3) = -0.2;
  145. SM_a.insert(2,4) = 0.3;
  146. SM_a.finalize();
  147. SM_b.insert(0,0) = 0.4;
  148. SM_b.insert(2,1) = -0.5;
  149. SM_b.finalize();
  150. SM_ab.resize(1,1);
  151. SM_ab.insert(0,0)=37.0;
  152. SM_ab = kroneckerProduct(SM_a,SM_b);
  153. CALL_SUBTEST(check_sparse_kronecker_product(SM_ab));
  154. // test dimension of result of DM = kroneckerProduct(DM,DM)
  155. MatrixXd DM_a2(2,1);
  156. MatrixXd DM_b2(5,4);
  157. MatrixXd DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
  158. CALL_SUBTEST(check_dimension(DM_ab2,2*5,1*4));
  159. DM_a2.resize(10,9);
  160. DM_b2.resize(4,8);
  161. DM_ab2 = kroneckerProduct(DM_a2,DM_b2);
  162. CALL_SUBTEST(check_dimension(DM_ab2,10*4,9*8));
  163. for(int i = 0; i < g_repeat; i++)
  164. {
  165. double density = Eigen::internal::random<double>(0.01,0.5);
  166. int ra = Eigen::internal::random<int>(1,50);
  167. int ca = Eigen::internal::random<int>(1,50);
  168. int rb = Eigen::internal::random<int>(1,50);
  169. int cb = Eigen::internal::random<int>(1,50);
  170. SparseMatrix<float,ColMajor> sA(ra,ca), sB(rb,cb), sC;
  171. SparseMatrix<float,RowMajor> sC2;
  172. MatrixXf dA(ra,ca), dB(rb,cb), dC;
  173. initSparse(density, dA, sA);
  174. initSparse(density, dB, sB);
  175. sC = kroneckerProduct(sA,sB);
  176. dC = kroneckerProduct(dA,dB);
  177. VERIFY_IS_APPROX(MatrixXf(sC),dC);
  178. sC = kroneckerProduct(sA.transpose(),sB);
  179. dC = kroneckerProduct(dA.transpose(),dB);
  180. VERIFY_IS_APPROX(MatrixXf(sC),dC);
  181. sC = kroneckerProduct(sA.transpose(),sB.transpose());
  182. dC = kroneckerProduct(dA.transpose(),dB.transpose());
  183. VERIFY_IS_APPROX(MatrixXf(sC),dC);
  184. sC = kroneckerProduct(sA,sB.transpose());
  185. dC = kroneckerProduct(dA,dB.transpose());
  186. VERIFY_IS_APPROX(MatrixXf(sC),dC);
  187. sC2 = kroneckerProduct(sA,sB);
  188. dC = kroneckerProduct(dA,dB);
  189. VERIFY_IS_APPROX(MatrixXf(sC2),dC);
  190. sC2 = kroneckerProduct(dA,sB);
  191. dC = kroneckerProduct(dA,dB);
  192. VERIFY_IS_APPROX(MatrixXf(sC2),dC);
  193. sC2 = kroneckerProduct(sA,dB);
  194. dC = kroneckerProduct(dA,dB);
  195. VERIFY_IS_APPROX(MatrixXf(sC2),dC);
  196. sC2 = kroneckerProduct(2*sA,sB);
  197. dC = kroneckerProduct(2*dA,dB);
  198. VERIFY_IS_APPROX(MatrixXf(sC2),dC);
  199. }
  200. }
  201. #endif
  202. #ifdef EIGEN_TEST_PART_2
  203. // simply check that for a dense kronecker product, sparse module is not needed
  204. #include "main.h"
  205. #include <Eigen/KroneckerProduct>
  206. EIGEN_DECLARE_TEST(kronecker_product)
  207. {
  208. MatrixXd a(2,2), b(3,3), c;
  209. a.setRandom();
  210. b.setRandom();
  211. c = kroneckerProduct(a,b);
  212. VERIFY_IS_APPROX(c.block(3,3,3,3), a(1,1)*b);
  213. }
  214. #endif