cxx11_tensor_simple.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2013 Christian Seiler <christian@iwakd.de>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::RowMajor;
  13. static void test_0d()
  14. {
  15. Tensor<int, 0> scalar1;
  16. Tensor<int, 0, RowMajor> scalar2;
  17. Tensor<int, 0> scalar3;
  18. Tensor<int, 0, RowMajor> scalar4;
  19. scalar3.resize();
  20. scalar4.resize();
  21. scalar1() = 7;
  22. scalar2() = 13;
  23. scalar3.setValues(17);
  24. scalar4.setZero();
  25. VERIFY_IS_EQUAL(scalar1.rank(), 0);
  26. VERIFY_IS_EQUAL(scalar1.size(), 1);
  27. VERIFY_IS_EQUAL(scalar1(), 7);
  28. VERIFY_IS_EQUAL(scalar2(), 13);
  29. VERIFY_IS_EQUAL(scalar3(), 17);
  30. VERIFY_IS_EQUAL(scalar4(), 0);
  31. Tensor<int, 0> scalar5(scalar1);
  32. VERIFY_IS_EQUAL(scalar5(), 7);
  33. VERIFY_IS_EQUAL(scalar5.data()[0], 7);
  34. }
  35. static void test_1d()
  36. {
  37. Tensor<int, 1> vec1(6);
  38. Tensor<int, 1, RowMajor> vec2(6);
  39. Tensor<int, 1> vec3;
  40. Tensor<int, 1, RowMajor> vec4;
  41. vec3.resize(6);
  42. vec4.resize(6);
  43. vec1(0) = 4; vec2(0) = 0; vec3(0) = 5;
  44. vec1(1) = 8; vec2(1) = 1; vec3(1) = 4;
  45. vec1(2) = 15; vec2(2) = 2; vec3(2) = 3;
  46. vec1(3) = 16; vec2(3) = 3; vec3(3) = 2;
  47. vec1(4) = 23; vec2(4) = 4; vec3(4) = 1;
  48. vec1(5) = 42; vec2(5) = 5; vec3(5) = 0;
  49. vec4.setZero();
  50. VERIFY_IS_EQUAL((vec1.rank()), 1);
  51. VERIFY_IS_EQUAL((vec1.size()), 6);
  52. VERIFY_IS_EQUAL((vec1.dimensions()[0]), 6);
  53. VERIFY_IS_EQUAL((vec1[0]), 4);
  54. VERIFY_IS_EQUAL((vec1[1]), 8);
  55. VERIFY_IS_EQUAL((vec1[2]), 15);
  56. VERIFY_IS_EQUAL((vec1[3]), 16);
  57. VERIFY_IS_EQUAL((vec1[4]), 23);
  58. VERIFY_IS_EQUAL((vec1[5]), 42);
  59. VERIFY_IS_EQUAL((vec2[0]), 0);
  60. VERIFY_IS_EQUAL((vec2[1]), 1);
  61. VERIFY_IS_EQUAL((vec2[2]), 2);
  62. VERIFY_IS_EQUAL((vec2[3]), 3);
  63. VERIFY_IS_EQUAL((vec2[4]), 4);
  64. VERIFY_IS_EQUAL((vec2[5]), 5);
  65. VERIFY_IS_EQUAL((vec3[0]), 5);
  66. VERIFY_IS_EQUAL((vec3[1]), 4);
  67. VERIFY_IS_EQUAL((vec3[2]), 3);
  68. VERIFY_IS_EQUAL((vec3[3]), 2);
  69. VERIFY_IS_EQUAL((vec3[4]), 1);
  70. VERIFY_IS_EQUAL((vec3[5]), 0);
  71. VERIFY_IS_EQUAL((vec4[0]), 0);
  72. VERIFY_IS_EQUAL((vec4[1]), 0);
  73. VERIFY_IS_EQUAL((vec4[2]), 0);
  74. VERIFY_IS_EQUAL((vec4[3]), 0);
  75. VERIFY_IS_EQUAL((vec4[4]), 0);
  76. VERIFY_IS_EQUAL((vec4[5]), 0);
  77. Tensor<int, 1> vec5(vec1);
  78. VERIFY_IS_EQUAL((vec5(0)), 4);
  79. VERIFY_IS_EQUAL((vec5(1)), 8);
  80. VERIFY_IS_EQUAL((vec5(2)), 15);
  81. VERIFY_IS_EQUAL((vec5(3)), 16);
  82. VERIFY_IS_EQUAL((vec5(4)), 23);
  83. VERIFY_IS_EQUAL((vec5(5)), 42);
  84. VERIFY_IS_EQUAL((vec5.data()[0]), 4);
  85. VERIFY_IS_EQUAL((vec5.data()[1]), 8);
  86. VERIFY_IS_EQUAL((vec5.data()[2]), 15);
  87. VERIFY_IS_EQUAL((vec5.data()[3]), 16);
  88. VERIFY_IS_EQUAL((vec5.data()[4]), 23);
  89. VERIFY_IS_EQUAL((vec5.data()[5]), 42);
  90. }
  91. static void test_2d()
  92. {
  93. Tensor<int, 2> mat1(2,3);
  94. Tensor<int, 2, RowMajor> mat2(2,3);
  95. mat1(0,0) = 0;
  96. mat1(0,1) = 1;
  97. mat1(0,2) = 2;
  98. mat1(1,0) = 3;
  99. mat1(1,1) = 4;
  100. mat1(1,2) = 5;
  101. mat2(0,0) = 0;
  102. mat2(0,1) = 1;
  103. mat2(0,2) = 2;
  104. mat2(1,0) = 3;
  105. mat2(1,1) = 4;
  106. mat2(1,2) = 5;
  107. VERIFY_IS_EQUAL((mat1.rank()), 2);
  108. VERIFY_IS_EQUAL((mat1.size()), 6);
  109. VERIFY_IS_EQUAL((mat1.dimensions()[0]), 2);
  110. VERIFY_IS_EQUAL((mat1.dimensions()[1]), 3);
  111. VERIFY_IS_EQUAL((mat2.rank()), 2);
  112. VERIFY_IS_EQUAL((mat2.size()), 6);
  113. VERIFY_IS_EQUAL((mat2.dimensions()[0]), 2);
  114. VERIFY_IS_EQUAL((mat2.dimensions()[1]), 3);
  115. VERIFY_IS_EQUAL((mat1.data()[0]), 0);
  116. VERIFY_IS_EQUAL((mat1.data()[1]), 3);
  117. VERIFY_IS_EQUAL((mat1.data()[2]), 1);
  118. VERIFY_IS_EQUAL((mat1.data()[3]), 4);
  119. VERIFY_IS_EQUAL((mat1.data()[4]), 2);
  120. VERIFY_IS_EQUAL((mat1.data()[5]), 5);
  121. VERIFY_IS_EQUAL((mat2.data()[0]), 0);
  122. VERIFY_IS_EQUAL((mat2.data()[1]), 1);
  123. VERIFY_IS_EQUAL((mat2.data()[2]), 2);
  124. VERIFY_IS_EQUAL((mat2.data()[3]), 3);
  125. VERIFY_IS_EQUAL((mat2.data()[4]), 4);
  126. VERIFY_IS_EQUAL((mat2.data()[5]), 5);
  127. }
  128. static void test_3d()
  129. {
  130. Tensor<int, 3> epsilon(3,3,3);
  131. epsilon.setZero();
  132. epsilon(0,1,2) = epsilon(2,0,1) = epsilon(1,2,0) = 1;
  133. epsilon(2,1,0) = epsilon(0,2,1) = epsilon(1,0,2) = -1;
  134. VERIFY_IS_EQUAL((epsilon.size()), 27);
  135. VERIFY_IS_EQUAL((epsilon.dimensions()[0]), 3);
  136. VERIFY_IS_EQUAL((epsilon.dimensions()[1]), 3);
  137. VERIFY_IS_EQUAL((epsilon.dimensions()[2]), 3);
  138. VERIFY_IS_EQUAL((epsilon(0,0,0)), 0);
  139. VERIFY_IS_EQUAL((epsilon(0,0,1)), 0);
  140. VERIFY_IS_EQUAL((epsilon(0,0,2)), 0);
  141. VERIFY_IS_EQUAL((epsilon(0,1,0)), 0);
  142. VERIFY_IS_EQUAL((epsilon(0,1,1)), 0);
  143. VERIFY_IS_EQUAL((epsilon(0,2,0)), 0);
  144. VERIFY_IS_EQUAL((epsilon(0,2,2)), 0);
  145. VERIFY_IS_EQUAL((epsilon(1,0,0)), 0);
  146. VERIFY_IS_EQUAL((epsilon(1,0,1)), 0);
  147. VERIFY_IS_EQUAL((epsilon(1,1,0)), 0);
  148. VERIFY_IS_EQUAL((epsilon(1,1,1)), 0);
  149. VERIFY_IS_EQUAL((epsilon(1,1,2)), 0);
  150. VERIFY_IS_EQUAL((epsilon(1,2,1)), 0);
  151. VERIFY_IS_EQUAL((epsilon(1,2,2)), 0);
  152. VERIFY_IS_EQUAL((epsilon(2,0,0)), 0);
  153. VERIFY_IS_EQUAL((epsilon(2,0,2)), 0);
  154. VERIFY_IS_EQUAL((epsilon(2,1,1)), 0);
  155. VERIFY_IS_EQUAL((epsilon(2,1,2)), 0);
  156. VERIFY_IS_EQUAL((epsilon(2,2,0)), 0);
  157. VERIFY_IS_EQUAL((epsilon(2,2,1)), 0);
  158. VERIFY_IS_EQUAL((epsilon(2,2,2)), 0);
  159. VERIFY_IS_EQUAL((epsilon(0,1,2)), 1);
  160. VERIFY_IS_EQUAL((epsilon(2,0,1)), 1);
  161. VERIFY_IS_EQUAL((epsilon(1,2,0)), 1);
  162. VERIFY_IS_EQUAL((epsilon(2,1,0)), -1);
  163. VERIFY_IS_EQUAL((epsilon(0,2,1)), -1);
  164. VERIFY_IS_EQUAL((epsilon(1,0,2)), -1);
  165. array<Eigen::DenseIndex, 3> dims;
  166. dims[0] = 2;
  167. dims[1] = 3;
  168. dims[2] = 4;
  169. Tensor<int, 3> t1(dims);
  170. Tensor<int, 3, RowMajor> t2(dims);
  171. VERIFY_IS_EQUAL((t1.size()), 24);
  172. VERIFY_IS_EQUAL((t1.dimensions()[0]), 2);
  173. VERIFY_IS_EQUAL((t1.dimensions()[1]), 3);
  174. VERIFY_IS_EQUAL((t1.dimensions()[2]), 4);
  175. VERIFY_IS_EQUAL((t2.size()), 24);
  176. VERIFY_IS_EQUAL((t2.dimensions()[0]), 2);
  177. VERIFY_IS_EQUAL((t2.dimensions()[1]), 3);
  178. VERIFY_IS_EQUAL((t2.dimensions()[2]), 4);
  179. for (int i = 0; i < 2; i++) {
  180. for (int j = 0; j < 3; j++) {
  181. for (int k = 0; k < 4; k++) {
  182. t1(i, j, k) = 100 * i + 10 * j + k;
  183. t2(i, j, k) = 100 * i + 10 * j + k;
  184. }
  185. }
  186. }
  187. VERIFY_IS_EQUAL((t1.data()[0]), 0);
  188. VERIFY_IS_EQUAL((t1.data()[1]), 100);
  189. VERIFY_IS_EQUAL((t1.data()[2]), 10);
  190. VERIFY_IS_EQUAL((t1.data()[3]), 110);
  191. VERIFY_IS_EQUAL((t1.data()[4]), 20);
  192. VERIFY_IS_EQUAL((t1.data()[5]), 120);
  193. VERIFY_IS_EQUAL((t1.data()[6]), 1);
  194. VERIFY_IS_EQUAL((t1.data()[7]), 101);
  195. VERIFY_IS_EQUAL((t1.data()[8]), 11);
  196. VERIFY_IS_EQUAL((t1.data()[9]), 111);
  197. VERIFY_IS_EQUAL((t1.data()[10]), 21);
  198. VERIFY_IS_EQUAL((t1.data()[11]), 121);
  199. VERIFY_IS_EQUAL((t1.data()[12]), 2);
  200. VERIFY_IS_EQUAL((t1.data()[13]), 102);
  201. VERIFY_IS_EQUAL((t1.data()[14]), 12);
  202. VERIFY_IS_EQUAL((t1.data()[15]), 112);
  203. VERIFY_IS_EQUAL((t1.data()[16]), 22);
  204. VERIFY_IS_EQUAL((t1.data()[17]), 122);
  205. VERIFY_IS_EQUAL((t1.data()[18]), 3);
  206. VERIFY_IS_EQUAL((t1.data()[19]), 103);
  207. VERIFY_IS_EQUAL((t1.data()[20]), 13);
  208. VERIFY_IS_EQUAL((t1.data()[21]), 113);
  209. VERIFY_IS_EQUAL((t1.data()[22]), 23);
  210. VERIFY_IS_EQUAL((t1.data()[23]), 123);
  211. VERIFY_IS_EQUAL((t2.data()[0]), 0);
  212. VERIFY_IS_EQUAL((t2.data()[1]), 1);
  213. VERIFY_IS_EQUAL((t2.data()[2]), 2);
  214. VERIFY_IS_EQUAL((t2.data()[3]), 3);
  215. VERIFY_IS_EQUAL((t2.data()[4]), 10);
  216. VERIFY_IS_EQUAL((t2.data()[5]), 11);
  217. VERIFY_IS_EQUAL((t2.data()[6]), 12);
  218. VERIFY_IS_EQUAL((t2.data()[7]), 13);
  219. VERIFY_IS_EQUAL((t2.data()[8]), 20);
  220. VERIFY_IS_EQUAL((t2.data()[9]), 21);
  221. VERIFY_IS_EQUAL((t2.data()[10]), 22);
  222. VERIFY_IS_EQUAL((t2.data()[11]), 23);
  223. VERIFY_IS_EQUAL((t2.data()[12]), 100);
  224. VERIFY_IS_EQUAL((t2.data()[13]), 101);
  225. VERIFY_IS_EQUAL((t2.data()[14]), 102);
  226. VERIFY_IS_EQUAL((t2.data()[15]), 103);
  227. VERIFY_IS_EQUAL((t2.data()[16]), 110);
  228. VERIFY_IS_EQUAL((t2.data()[17]), 111);
  229. VERIFY_IS_EQUAL((t2.data()[18]), 112);
  230. VERIFY_IS_EQUAL((t2.data()[19]), 113);
  231. VERIFY_IS_EQUAL((t2.data()[20]), 120);
  232. VERIFY_IS_EQUAL((t2.data()[21]), 121);
  233. VERIFY_IS_EQUAL((t2.data()[22]), 122);
  234. VERIFY_IS_EQUAL((t2.data()[23]), 123);
  235. }
  236. static void test_simple_assign()
  237. {
  238. Tensor<int, 3> epsilon(3,3,3);
  239. epsilon.setZero();
  240. epsilon(0,1,2) = epsilon(2,0,1) = epsilon(1,2,0) = 1;
  241. epsilon(2,1,0) = epsilon(0,2,1) = epsilon(1,0,2) = -1;
  242. Tensor<int, 3> e2(3,3,3);
  243. e2.setZero();
  244. VERIFY_IS_EQUAL((e2(1,2,0)), 0);
  245. e2 = epsilon;
  246. VERIFY_IS_EQUAL((e2(1,2,0)), 1);
  247. VERIFY_IS_EQUAL((e2(0,1,2)), 1);
  248. VERIFY_IS_EQUAL((e2(2,0,1)), 1);
  249. VERIFY_IS_EQUAL((e2(2,1,0)), -1);
  250. VERIFY_IS_EQUAL((e2(0,2,1)), -1);
  251. VERIFY_IS_EQUAL((e2(1,0,2)), -1);
  252. }
  253. static void test_resize()
  254. {
  255. Tensor<int, 3> epsilon;
  256. epsilon.resize(2,3,7);
  257. VERIFY_IS_EQUAL(epsilon.dimension(0), 2);
  258. VERIFY_IS_EQUAL(epsilon.dimension(1), 3);
  259. VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
  260. VERIFY_IS_EQUAL(epsilon.size(), 2*3*7);
  261. const int* old_data = epsilon.data();
  262. epsilon.resize(3,2,7);
  263. VERIFY_IS_EQUAL(epsilon.dimension(0), 3);
  264. VERIFY_IS_EQUAL(epsilon.dimension(1), 2);
  265. VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
  266. VERIFY_IS_EQUAL(epsilon.size(), 2*3*7);
  267. VERIFY_IS_EQUAL(epsilon.data(), old_data);
  268. epsilon.resize(3,5,7);
  269. VERIFY_IS_EQUAL(epsilon.dimension(0), 3);
  270. VERIFY_IS_EQUAL(epsilon.dimension(1), 5);
  271. VERIFY_IS_EQUAL(epsilon.dimension(2), 7);
  272. VERIFY_IS_EQUAL(epsilon.size(), 3*5*7);
  273. }
  274. EIGEN_DECLARE_TEST(cxx11_tensor_simple)
  275. {
  276. CALL_SUBTEST(test_0d());
  277. CALL_SUBTEST(test_1d());
  278. CALL_SUBTEST(test_2d());
  279. CALL_SUBTEST(test_3d());
  280. CALL_SUBTEST(test_simple_assign());
  281. CALL_SUBTEST(test_resize());
  282. }