// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
#include <Eigen/CXX11/Tensor>

#include <cstddef>

using Eigen::DefaultDevice;
using Eigen::Tensor;
  13. typedef Tensor<float, 1>::DimensionPair DimPair;
  14. template<int DataLayout>
  15. static void test_evals()
  16. {
  17. Tensor<float, 2, DataLayout> mat1(2, 3);
  18. Tensor<float, 2, DataLayout> mat2(2, 3);
  19. Tensor<float, 2, DataLayout> mat3(3, 2);
  20. mat1.setRandom();
  21. mat2.setRandom();
  22. mat3.setRandom();
  23. Tensor<float, 2, DataLayout> mat4(3,3);
  24. mat4.setZero();
  25. Eigen::array<DimPair, 1> dims3 = {{DimPair(0, 0)}};
  26. typedef TensorEvaluator<decltype(mat1.contract(mat2, dims3)), DefaultDevice> Evaluator;
  27. Evaluator eval(mat1.contract(mat2, dims3), DefaultDevice());
  28. eval.evalTo(mat4.data());
  29. EIGEN_STATIC_ASSERT(Evaluator::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  30. VERIFY_IS_EQUAL(eval.dimensions()[0], 3);
  31. VERIFY_IS_EQUAL(eval.dimensions()[1], 3);
  32. VERIFY_IS_APPROX(mat4(0,0), mat1(0,0)*mat2(0,0) + mat1(1,0)*mat2(1,0));
  33. VERIFY_IS_APPROX(mat4(0,1), mat1(0,0)*mat2(0,1) + mat1(1,0)*mat2(1,1));
  34. VERIFY_IS_APPROX(mat4(0,2), mat1(0,0)*mat2(0,2) + mat1(1,0)*mat2(1,2));
  35. VERIFY_IS_APPROX(mat4(1,0), mat1(0,1)*mat2(0,0) + mat1(1,1)*mat2(1,0));
  36. VERIFY_IS_APPROX(mat4(1,1), mat1(0,1)*mat2(0,1) + mat1(1,1)*mat2(1,1));
  37. VERIFY_IS_APPROX(mat4(1,2), mat1(0,1)*mat2(0,2) + mat1(1,1)*mat2(1,2));
  38. VERIFY_IS_APPROX(mat4(2,0), mat1(0,2)*mat2(0,0) + mat1(1,2)*mat2(1,0));
  39. VERIFY_IS_APPROX(mat4(2,1), mat1(0,2)*mat2(0,1) + mat1(1,2)*mat2(1,1));
  40. VERIFY_IS_APPROX(mat4(2,2), mat1(0,2)*mat2(0,2) + mat1(1,2)*mat2(1,2));
  41. Tensor<float, 2, DataLayout> mat5(2,2);
  42. mat5.setZero();
  43. Eigen::array<DimPair, 1> dims4 = {{DimPair(1, 1)}};
  44. typedef TensorEvaluator<decltype(mat1.contract(mat2, dims4)), DefaultDevice> Evaluator2;
  45. Evaluator2 eval2(mat1.contract(mat2, dims4), DefaultDevice());
  46. eval2.evalTo(mat5.data());
  47. EIGEN_STATIC_ASSERT(Evaluator2::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  48. VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
  49. VERIFY_IS_EQUAL(eval2.dimensions()[1], 2);
  50. VERIFY_IS_APPROX(mat5(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(0,1) + mat1(0,2)*mat2(0,2));
  51. VERIFY_IS_APPROX(mat5(0,1), mat1(0,0)*mat2(1,0) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(1,2));
  52. VERIFY_IS_APPROX(mat5(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(0,1) + mat1(1,2)*mat2(0,2));
  53. VERIFY_IS_APPROX(mat5(1,1), mat1(1,0)*mat2(1,0) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(1,2));
  54. Tensor<float, 2, DataLayout> mat6(2,2);
  55. mat6.setZero();
  56. Eigen::array<DimPair, 1> dims6 = {{DimPair(1, 0)}};
  57. typedef TensorEvaluator<decltype(mat1.contract(mat3, dims6)), DefaultDevice> Evaluator3;
  58. Evaluator3 eval3(mat1.contract(mat3, dims6), DefaultDevice());
  59. eval3.evalTo(mat6.data());
  60. EIGEN_STATIC_ASSERT(Evaluator3::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  61. VERIFY_IS_EQUAL(eval3.dimensions()[0], 2);
  62. VERIFY_IS_EQUAL(eval3.dimensions()[1], 2);
  63. VERIFY_IS_APPROX(mat6(0,0), mat1(0,0)*mat3(0,0) + mat1(0,1)*mat3(1,0) + mat1(0,2)*mat3(2,0));
  64. VERIFY_IS_APPROX(mat6(0,1), mat1(0,0)*mat3(0,1) + mat1(0,1)*mat3(1,1) + mat1(0,2)*mat3(2,1));
  65. VERIFY_IS_APPROX(mat6(1,0), mat1(1,0)*mat3(0,0) + mat1(1,1)*mat3(1,0) + mat1(1,2)*mat3(2,0));
  66. VERIFY_IS_APPROX(mat6(1,1), mat1(1,0)*mat3(0,1) + mat1(1,1)*mat3(1,1) + mat1(1,2)*mat3(2,1));
  67. }
  68. template<int DataLayout>
  69. static void test_scalar()
  70. {
  71. Tensor<float, 1, DataLayout> vec1({6});
  72. Tensor<float, 1, DataLayout> vec2({6});
  73. vec1.setRandom();
  74. vec2.setRandom();
  75. Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}};
  76. Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims);
  77. float expected = 0.0f;
  78. for (int i = 0; i < 6; ++i) {
  79. expected += vec1(i) * vec2(i);
  80. }
  81. VERIFY_IS_APPROX(scalar(), expected);
  82. }
  83. template<int DataLayout>
  84. static void test_multidims()
  85. {
  86. Tensor<float, 3, DataLayout> mat1(2, 2, 2);
  87. Tensor<float, 4, DataLayout> mat2(2, 2, 2, 2);
  88. mat1.setRandom();
  89. mat2.setRandom();
  90. Tensor<float, 3, DataLayout> mat3(2, 2, 2);
  91. mat3.setZero();
  92. Eigen::array<DimPair, 2> dims = {{DimPair(1, 2), DimPair(2, 3)}};
  93. typedef TensorEvaluator<decltype(mat1.contract(mat2, dims)), DefaultDevice> Evaluator;
  94. Evaluator eval(mat1.contract(mat2, dims), DefaultDevice());
  95. eval.evalTo(mat3.data());
  96. EIGEN_STATIC_ASSERT(Evaluator::NumDims==3ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  97. VERIFY_IS_EQUAL(eval.dimensions()[0], 2);
  98. VERIFY_IS_EQUAL(eval.dimensions()[1], 2);
  99. VERIFY_IS_EQUAL(eval.dimensions()[2], 2);
  100. VERIFY_IS_APPROX(mat3(0,0,0), mat1(0,0,0)*mat2(0,0,0,0) + mat1(0,1,0)*mat2(0,0,1,0) +
  101. mat1(0,0,1)*mat2(0,0,0,1) + mat1(0,1,1)*mat2(0,0,1,1));
  102. VERIFY_IS_APPROX(mat3(0,0,1), mat1(0,0,0)*mat2(0,1,0,0) + mat1(0,1,0)*mat2(0,1,1,0) +
  103. mat1(0,0,1)*mat2(0,1,0,1) + mat1(0,1,1)*mat2(0,1,1,1));
  104. VERIFY_IS_APPROX(mat3(0,1,0), mat1(0,0,0)*mat2(1,0,0,0) + mat1(0,1,0)*mat2(1,0,1,0) +
  105. mat1(0,0,1)*mat2(1,0,0,1) + mat1(0,1,1)*mat2(1,0,1,1));
  106. VERIFY_IS_APPROX(mat3(0,1,1), mat1(0,0,0)*mat2(1,1,0,0) + mat1(0,1,0)*mat2(1,1,1,0) +
  107. mat1(0,0,1)*mat2(1,1,0,1) + mat1(0,1,1)*mat2(1,1,1,1));
  108. VERIFY_IS_APPROX(mat3(1,0,0), mat1(1,0,0)*mat2(0,0,0,0) + mat1(1,1,0)*mat2(0,0,1,0) +
  109. mat1(1,0,1)*mat2(0,0,0,1) + mat1(1,1,1)*mat2(0,0,1,1));
  110. VERIFY_IS_APPROX(mat3(1,0,1), mat1(1,0,0)*mat2(0,1,0,0) + mat1(1,1,0)*mat2(0,1,1,0) +
  111. mat1(1,0,1)*mat2(0,1,0,1) + mat1(1,1,1)*mat2(0,1,1,1));
  112. VERIFY_IS_APPROX(mat3(1,1,0), mat1(1,0,0)*mat2(1,0,0,0) + mat1(1,1,0)*mat2(1,0,1,0) +
  113. mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
  114. VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
  115. mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
  116. Tensor<float, 2, DataLayout> mat4(2, 2);
  117. Tensor<float, 3, DataLayout> mat5(2, 2, 2);
  118. mat4.setRandom();
  119. mat5.setRandom();
  120. Tensor<float, 1, DataLayout> mat6(2);
  121. mat6.setZero();
  122. Eigen::array<DimPair, 2> dims2({{DimPair(0, 1), DimPair(1, 0)}});
  123. typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)), DefaultDevice> Evaluator2;
  124. Evaluator2 eval2(mat4.contract(mat5, dims2), DefaultDevice());
  125. eval2.evalTo(mat6.data());
  126. EIGEN_STATIC_ASSERT(Evaluator2::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  127. VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
  128. VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) +
  129. mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
  130. VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) +
  131. mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
  132. }
  133. template<int DataLayout>
  134. static void test_holes() {
  135. Tensor<float, 4, DataLayout> t1(2, 5, 7, 3);
  136. Tensor<float, 5, DataLayout> t2(2, 7, 11, 13, 3);
  137. t1.setRandom();
  138. t2.setRandom();
  139. Eigen::array<DimPair, 2> dims = {{DimPair(0, 0), DimPair(3, 4)}};
  140. Tensor<float, 5, DataLayout> result = t1.contract(t2, dims);
  141. VERIFY_IS_EQUAL(result.dimension(0), 5);
  142. VERIFY_IS_EQUAL(result.dimension(1), 7);
  143. VERIFY_IS_EQUAL(result.dimension(2), 7);
  144. VERIFY_IS_EQUAL(result.dimension(3), 11);
  145. VERIFY_IS_EQUAL(result.dimension(4), 13);
  146. for (int i = 0; i < 5; ++i) {
  147. for (int j = 0; j < 5; ++j) {
  148. for (int k = 0; k < 5; ++k) {
  149. for (int l = 0; l < 5; ++l) {
  150. for (int m = 0; m < 5; ++m) {
  151. VERIFY_IS_APPROX(result(i, j, k, l, m),
  152. t1(0, i, j, 0) * t2(0, k, l, m, 0) +
  153. t1(1, i, j, 0) * t2(1, k, l, m, 0) +
  154. t1(0, i, j, 1) * t2(0, k, l, m, 1) +
  155. t1(1, i, j, 1) * t2(1, k, l, m, 1) +
  156. t1(0, i, j, 2) * t2(0, k, l, m, 2) +
  157. t1(1, i, j, 2) * t2(1, k, l, m, 2));
  158. }
  159. }
  160. }
  161. }
  162. }
  163. }
  164. template<int DataLayout>
  165. static void test_full_redux()
  166. {
  167. Tensor<float, 2, DataLayout> t1(2, 2);
  168. Tensor<float, 3, DataLayout> t2(2, 2, 2);
  169. t1.setRandom();
  170. t2.setRandom();
  171. Eigen::array<DimPair, 2> dims = {{DimPair(0, 0), DimPair(1, 1)}};
  172. Tensor<float, 1, DataLayout> result = t1.contract(t2, dims);
  173. VERIFY_IS_EQUAL(result.dimension(0), 2);
  174. VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(1, 0, 0)
  175. + t1(0, 1) * t2(0, 1, 0) + t1(1, 1) * t2(1, 1, 0));
  176. VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(0, 0, 1) + t1(1, 0) * t2(1, 0, 1)
  177. + t1(0, 1) * t2(0, 1, 1) + t1(1, 1) * t2(1, 1, 1));
  178. dims[0] = DimPair(1, 0);
  179. dims[1] = DimPair(2, 1);
  180. result = t2.contract(t1, dims);
  181. VERIFY_IS_EQUAL(result.dimension(0), 2);
  182. VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(0, 1, 0)
  183. + t1(0, 1) * t2(0, 0, 1) + t1(1, 1) * t2(0, 1, 1));
  184. VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(1, 0, 0) + t1(1, 0) * t2(1, 1, 0)
  185. + t1(0, 1) * t2(1, 0, 1) + t1(1, 1) * t2(1, 1, 1));
  186. }
  187. template<int DataLayout>
  188. static void test_contraction_of_contraction()
  189. {
  190. Tensor<float, 2, DataLayout> t1(2, 2);
  191. Tensor<float, 2, DataLayout> t2(2, 2);
  192. Tensor<float, 2, DataLayout> t3(2, 2);
  193. Tensor<float, 2, DataLayout> t4(2, 2);
  194. t1.setRandom();
  195. t2.setRandom();
  196. t3.setRandom();
  197. t4.setRandom();
  198. Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
  199. auto contract1 = t1.contract(t2, dims);
  200. auto diff = t3 - contract1;
  201. auto contract2 = t1.contract(t4, dims);
  202. Tensor<float, 2, DataLayout> result = contract2.contract(diff, dims);
  203. VERIFY_IS_EQUAL(result.dimension(0), 2);
  204. VERIFY_IS_EQUAL(result.dimension(1), 2);
  205. Eigen::Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>>
  206. m1(t1.data(), 2, 2), m2(t2.data(), 2, 2), m3(t3.data(), 2, 2),
  207. m4(t4.data(), 2, 2);
  208. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>
  209. expected = (m1 * m4) * (m3 - m1 * m2);
  210. VERIFY_IS_APPROX(result(0, 0), expected(0, 0));
  211. VERIFY_IS_APPROX(result(0, 1), expected(0, 1));
  212. VERIFY_IS_APPROX(result(1, 0), expected(1, 0));
  213. VERIFY_IS_APPROX(result(1, 1), expected(1, 1));
  214. }
  215. template<int DataLayout>
  216. static void test_expr()
  217. {
  218. Tensor<float, 2, DataLayout> mat1(2, 3);
  219. Tensor<float, 2, DataLayout> mat2(3, 2);
  220. mat1.setRandom();
  221. mat2.setRandom();
  222. Tensor<float, 2, DataLayout> mat3(2,2);
  223. Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
  224. mat3 = mat1.contract(mat2, dims);
  225. VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
  226. VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
  227. VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
  228. VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
  229. }
  230. template<int DataLayout>
  231. static void test_out_of_order_contraction()
  232. {
  233. Tensor<float, 3, DataLayout> mat1(2, 2, 2);
  234. Tensor<float, 3, DataLayout> mat2(2, 2, 2);
  235. mat1.setRandom();
  236. mat2.setRandom();
  237. Tensor<float, 2, DataLayout> mat3(2, 2);
  238. Eigen::array<DimPair, 2> dims = {{DimPair(2, 0), DimPair(0, 2)}};
  239. mat3 = mat1.contract(mat2, dims);
  240. VERIFY_IS_APPROX(mat3(0, 0),
  241. mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
  242. mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
  243. VERIFY_IS_APPROX(mat3(1, 0),
  244. mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
  245. mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
  246. VERIFY_IS_APPROX(mat3(0, 1),
  247. mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
  248. mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
  249. VERIFY_IS_APPROX(mat3(1, 1),
  250. mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
  251. mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
  252. Eigen::array<DimPair, 2> dims2 = {{DimPair(0, 2), DimPair(2, 0)}};
  253. mat3 = mat1.contract(mat2, dims2);
  254. VERIFY_IS_APPROX(mat3(0, 0),
  255. mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
  256. mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
  257. VERIFY_IS_APPROX(mat3(1, 0),
  258. mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
  259. mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
  260. VERIFY_IS_APPROX(mat3(0, 1),
  261. mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
  262. mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
  263. VERIFY_IS_APPROX(mat3(1, 1),
  264. mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
  265. mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
  266. }
  267. template<int DataLayout>
  268. static void test_consistency()
  269. {
  270. // this does something like testing (A*B)^T = (B^T * A^T)
  271. Tensor<float, 3, DataLayout> mat1(4, 3, 5);
  272. Tensor<float, 5, DataLayout> mat2(3, 2, 1, 5, 4);
  273. mat1.setRandom();
  274. mat2.setRandom();
  275. Tensor<float, 4, DataLayout> mat3(5, 2, 1, 5);
  276. Tensor<float, 4, DataLayout> mat4(2, 1, 5, 5);
  277. // contract on dimensions of size 4 and 3
  278. Eigen::array<DimPair, 2> dims1 = {{DimPair(0, 4), DimPair(1, 0)}};
  279. Eigen::array<DimPair, 2> dims2 = {{DimPair(4, 0), DimPair(0, 1)}};
  280. mat3 = mat1.contract(mat2, dims1);
  281. mat4 = mat2.contract(mat1, dims2);
  282. // check that these are equal except for ordering of dimensions
  283. if (DataLayout == ColMajor) {
  284. for (size_t i = 0; i < 5; i++) {
  285. for (size_t j = 0; j < 10; j++) {
  286. VERIFY_IS_APPROX(mat3.data()[i + 5 * j], mat4.data()[j + 10 * i]);
  287. }
  288. }
  289. } else {
  290. // Row major
  291. for (size_t i = 0; i < 5; i++) {
  292. for (size_t j = 0; j < 10; j++) {
  293. VERIFY_IS_APPROX(mat3.data()[10 * i + j], mat4.data()[i + 5 * j]);
  294. }
  295. }
  296. }
  297. }
  298. template<int DataLayout>
  299. static void test_large_contraction()
  300. {
  301. Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
  302. Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
  303. Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
  304. t_left.setRandom();
  305. t_right.setRandom();
  306. // Add a little offset so that the results won't be close to zero.
  307. t_left += t_left.constant(1.0f);
  308. t_right += t_right.constant(1.0f);
  309. typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
  310. MapXf m_left(t_left.data(), 1500, 248);
  311. MapXf m_right(t_right.data(), 248, 1400);
  312. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
  313. // this contraction should be equivalent to a single matrix multiplication
  314. Eigen::array<DimPair, 2> dims = {{DimPair(2, 0), DimPair(3, 1)}};
  315. // compute results by separate methods
  316. t_result = t_left.contract(t_right, dims);
  317. m_result = m_left * m_right;
  318. for (int i = 0; i < t_result.dimensions().TotalSize(); i++) {
  319. VERIFY(&t_result.data()[i] != &m_result.data()[i]);
  320. VERIFY_IS_APPROX(t_result.data()[i], m_result.data()[i]);
  321. }
  322. }
  323. template<int DataLayout>
  324. static void test_matrix_vector()
  325. {
  326. Tensor<float, 2, DataLayout> t_left(30, 50);
  327. Tensor<float, 1, DataLayout> t_right(50);
  328. Tensor<float, 1, DataLayout> t_result(30);
  329. t_left.setRandom();
  330. t_right.setRandom();
  331. typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
  332. MapXf m_left(t_left.data(), 30, 50);
  333. MapXf m_right(t_right.data(), 50, 1);
  334. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(30, 1);
  335. // this contraction should be equivalent to a single matrix multiplication
  336. Eigen::array<DimPair, 1> dims{{DimPair(1, 0)}};
  337. // compute results by separate methods
  338. t_result = t_left.contract(t_right, dims);
  339. m_result = m_left * m_right;
  340. for (int i = 0; i < t_result.dimensions().TotalSize(); i++) {
  341. VERIFY(internal::isApprox(t_result(i), m_result(i, 0), 1));
  342. }
  343. }
  344. template<int DataLayout>
  345. static void test_tensor_vector()
  346. {
  347. Tensor<float, 3, DataLayout> t_left(7, 13, 17);
  348. Tensor<float, 2, DataLayout> t_right(1, 7);
  349. t_left.setRandom();
  350. t_right.setRandom();
  351. typedef typename Tensor<float, 1, DataLayout>::DimensionPair DimensionPair;
  352. Eigen::array<DimensionPair, 1> dim_pair01{{{0, 1}}};
  353. Tensor<float, 3, DataLayout> t_result = t_left.contract(t_right, dim_pair01);
  354. typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
  355. MapXf m_left(t_left.data(), 7, 13*17);
  356. MapXf m_right(t_right.data(), 1, 7);
  357. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result = m_left.transpose() * m_right.transpose();
  358. for (int i = 0; i < t_result.dimensions().TotalSize(); i++) {
  359. VERIFY(internal::isApprox(t_result(i), m_result(i, 0), 1));
  360. }
  361. }
  362. template<int DataLayout>
  363. static void test_small_blocking_factors()
  364. {
  365. Tensor<float, 4, DataLayout> t_left(30, 5, 3, 31);
  366. Tensor<float, 5, DataLayout> t_right(3, 31, 7, 20, 1);
  367. t_left.setRandom();
  368. t_right.setRandom();
  369. // Add a little offset so that the results won't be close to zero.
  370. t_left += t_left.constant(1.0f);
  371. t_right += t_right.constant(1.0f);
  372. // Force the cache sizes, which results in smaller blocking factors.
  373. Eigen::setCpuCacheSizes(896, 1920, 2944);
  374. // this contraction should be equivalent to a single matrix multiplication
  375. Eigen::array<DimPair, 2> dims = {{DimPair(2, 0), DimPair(3, 1)}};
  376. Tensor<float, 5, DataLayout> t_result;
  377. t_result = t_left.contract(t_right, dims);
  378. // compute result using a simple eigen matrix product
  379. Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> m_left(t_left.data(), 150, 93);
  380. Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> m_right(t_right.data(), 93, 140);
  381. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result = m_left * m_right;
  382. for (int i = 0; i < t_result.dimensions().TotalSize(); i++) {
  383. VERIFY_IS_APPROX(t_result.data()[i], m_result.data()[i]);
  384. }
  385. }
  386. template<int DataLayout>
  387. static void test_tensor_product()
  388. {
  389. Tensor<float, 2, DataLayout> mat1(2, 3);
  390. Tensor<float, 2, DataLayout> mat2(4, 1);
  391. mat1.setRandom();
  392. mat2.setRandom();
  393. Eigen::array<DimPair, 0> dims;
  394. Tensor<float, 4, DataLayout> result = mat1.contract(mat2, dims);
  395. VERIFY_IS_EQUAL(result.dimension(0), 2);
  396. VERIFY_IS_EQUAL(result.dimension(1), 3);
  397. VERIFY_IS_EQUAL(result.dimension(2), 4);
  398. VERIFY_IS_EQUAL(result.dimension(3), 1);
  399. for (int i = 0; i < result.dimension(0); ++i) {
  400. for (int j = 0; j < result.dimension(1); ++j) {
  401. for (int k = 0; k < result.dimension(2); ++k) {
  402. for (int l = 0; l < result.dimension(3); ++l) {
  403. VERIFY_IS_APPROX(result(i, j, k, l), mat1(i, j) * mat2(k, l) );
  404. }
  405. }
  406. }
  407. }
  408. }
  409. template<int DataLayout>
  410. static void test_const_inputs()
  411. {
  412. Tensor<float, 2, DataLayout> in1(2, 3);
  413. Tensor<float, 2, DataLayout> in2(3, 2);
  414. in1.setRandom();
  415. in2.setRandom();
  416. TensorMap<Tensor<const float, 2, DataLayout> > mat1(in1.data(), 2, 3);
  417. TensorMap<Tensor<const float, 2, DataLayout> > mat2(in2.data(), 3, 2);
  418. Tensor<float, 2, DataLayout> mat3(2,2);
  419. Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
  420. mat3 = mat1.contract(mat2, dims);
  421. VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
  422. VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
  423. VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
  424. VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
  425. }
  426. // Apply Sqrt to all output elements.
  427. struct SqrtOutputKernel {
  428. template <typename Index, typename Scalar>
  429. EIGEN_ALWAYS_INLINE void operator()(
  430. const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
  431. const TensorContractionParams&, Index, Index, Index num_rows,
  432. Index num_cols) const {
  433. for (int i = 0; i < num_rows; ++i) {
  434. for (int j = 0; j < num_cols; ++j) {
  435. output_mapper(i, j) = std::sqrt(output_mapper(i, j));
  436. }
  437. }
  438. }
  439. };
  440. template <int DataLayout>
  441. static void test_large_contraction_with_output_kernel() {
  442. Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
  443. Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
  444. Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
  445. t_left.setRandom();
  446. t_right.setRandom();
  447. // Put trash in mat4 to verify contraction clears output memory.
  448. t_result.setRandom();
  449. // Add a little offset so that the results won't be close to zero.
  450. t_left += t_left.constant(1.0f);
  451. t_right += t_right.constant(1.0f);
  452. typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
  453. MapXf m_left(t_left.data(), 1500, 248);
  454. MapXf m_right(t_right.data(), 248, 1400);
  455. Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
  456. // this contraction should be equivalent to a single matrix multiplication
  457. Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
  458. // compute results by separate methods
  459. t_result = t_left.contract(t_right, dims, SqrtOutputKernel());
  460. m_result = m_left * m_right;
  461. for (std::ptrdiff_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
  462. VERIFY(&t_result.data()[i] != &m_result.data()[i]);
  463. VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i]));
  464. }
  465. }
  // Test driver: runs every test above for both ColMajor and RowMajor layouts.
  // The numeric suffix of each CALL_SUBTEST_* macro assigns the call to a part.
  // The EIGEN_SUFFIXES comment at the end lists those parts so that CMake
  // splits this test accordingly; keep that comment in sync with the suffixes used.
  EIGEN_DECLARE_TEST(cxx11_tensor_contraction)
  {
    // Raw evaluator contractions and rank-0 results.
    CALL_SUBTEST_1(test_evals<ColMajor>());
    CALL_SUBTEST_1(test_evals<RowMajor>());
    CALL_SUBTEST_1(test_scalar<ColMajor>());
    CALL_SUBTEST_1(test_scalar<RowMajor>());
    // Multiple contracted dimension pairs, including non-adjacent ones.
    CALL_SUBTEST_2(test_multidims<ColMajor>());
    CALL_SUBTEST_2(test_multidims<RowMajor>());
    CALL_SUBTEST_2(test_holes<ColMajor>());
    CALL_SUBTEST_2(test_holes<RowMajor>());
    // Full reductions and nested (unevaluated) contraction expressions.
    CALL_SUBTEST_3(test_full_redux<ColMajor>());
    CALL_SUBTEST_3(test_full_redux<RowMajor>());
    CALL_SUBTEST_3(test_contraction_of_contraction<ColMajor>());
    CALL_SUBTEST_3(test_contraction_of_contraction<RowMajor>());
    // Assignment from contraction expressions and dimension-pair ordering.
    CALL_SUBTEST_4(test_expr<ColMajor>());
    CALL_SUBTEST_4(test_expr<RowMajor>());
    CALL_SUBTEST_4(test_out_of_order_contraction<ColMajor>());
    CALL_SUBTEST_4(test_out_of_order_contraction<RowMajor>());
    // Operand-swap consistency and a large contraction vs. a matrix product.
    CALL_SUBTEST_5(test_consistency<ColMajor>());
    CALL_SUBTEST_5(test_consistency<RowMajor>());
    CALL_SUBTEST_5(test_large_contraction<ColMajor>());
    CALL_SUBTEST_5(test_large_contraction<RowMajor>());
    // Matrix-vector and tensor-vector contractions.
    CALL_SUBTEST_6(test_matrix_vector<ColMajor>());
    CALL_SUBTEST_6(test_matrix_vector<RowMajor>());
    CALL_SUBTEST_6(test_tensor_vector<ColMajor>());
    CALL_SUBTEST_6(test_tensor_vector<RowMajor>());
    // Forced small cache sizes, and the zero-pair (outer product) case.
    CALL_SUBTEST_7(test_small_blocking_factors<ColMajor>());
    CALL_SUBTEST_7(test_small_blocking_factors<RowMajor>());
    CALL_SUBTEST_7(test_tensor_product<ColMajor>());
    CALL_SUBTEST_7(test_tensor_product<RowMajor>());
    // Const TensorMap operands and custom output kernels.
    CALL_SUBTEST_8(test_const_inputs<ColMajor>());
    CALL_SUBTEST_8(test_const_inputs<RowMajor>());
    CALL_SUBTEST_8(test_large_contraction_with_output_kernel<ColMajor>());
    CALL_SUBTEST_8(test_large_contraction_with_output_kernel<RowMajor>());
    // Force CMake to split this test.
    // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8
  }