cxx11_tensor_assign.cpp 9.5 KB


  // This file is part of Eigen, a lightweight C++ template library
  // for linear algebra.
  //
  // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  //
  // This Source Code Form is subject to the terms of the Mozilla
  // Public License v. 2.0. If a copy of the MPL was not distributed
  // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::RowMajor;
  13. static void test_1d()
  14. {
  15. Tensor<int, 1> vec1(6);
  16. Tensor<int, 1, RowMajor> vec2(6);
  17. vec1(0) = 4; vec2(0) = 0;
  18. vec1(1) = 8; vec2(1) = 1;
  19. vec1(2) = 15; vec2(2) = 2;
  20. vec1(3) = 16; vec2(3) = 3;
  21. vec1(4) = 23; vec2(4) = 4;
  22. vec1(5) = 42; vec2(5) = 5;
  23. int col_major[6];
  24. int row_major[6];
  25. memset(col_major, 0, 6*sizeof(int));
  26. memset(row_major, 0, 6*sizeof(int));
  27. TensorMap<Tensor<int, 1> > vec3(col_major, 6);
  28. TensorMap<Tensor<int, 1, RowMajor> > vec4(row_major, 6);
  29. vec3 = vec1;
  30. vec4 = vec2;
  31. VERIFY_IS_EQUAL(vec3(0), 4);
  32. VERIFY_IS_EQUAL(vec3(1), 8);
  33. VERIFY_IS_EQUAL(vec3(2), 15);
  34. VERIFY_IS_EQUAL(vec3(3), 16);
  35. VERIFY_IS_EQUAL(vec3(4), 23);
  36. VERIFY_IS_EQUAL(vec3(5), 42);
  37. VERIFY_IS_EQUAL(vec4(0), 0);
  38. VERIFY_IS_EQUAL(vec4(1), 1);
  39. VERIFY_IS_EQUAL(vec4(2), 2);
  40. VERIFY_IS_EQUAL(vec4(3), 3);
  41. VERIFY_IS_EQUAL(vec4(4), 4);
  42. VERIFY_IS_EQUAL(vec4(5), 5);
  43. vec1.setZero();
  44. vec2.setZero();
  45. vec1 = vec3;
  46. vec2 = vec4;
  47. VERIFY_IS_EQUAL(vec1(0), 4);
  48. VERIFY_IS_EQUAL(vec1(1), 8);
  49. VERIFY_IS_EQUAL(vec1(2), 15);
  50. VERIFY_IS_EQUAL(vec1(3), 16);
  51. VERIFY_IS_EQUAL(vec1(4), 23);
  52. VERIFY_IS_EQUAL(vec1(5), 42);
  53. VERIFY_IS_EQUAL(vec2(0), 0);
  54. VERIFY_IS_EQUAL(vec2(1), 1);
  55. VERIFY_IS_EQUAL(vec2(2), 2);
  56. VERIFY_IS_EQUAL(vec2(3), 3);
  57. VERIFY_IS_EQUAL(vec2(4), 4);
  58. VERIFY_IS_EQUAL(vec2(5), 5);
  59. }
  60. static void test_2d()
  61. {
  62. Tensor<int, 2> mat1(2,3);
  63. Tensor<int, 2, RowMajor> mat2(2,3);
  64. mat1(0,0) = 0;
  65. mat1(0,1) = 1;
  66. mat1(0,2) = 2;
  67. mat1(1,0) = 3;
  68. mat1(1,1) = 4;
  69. mat1(1,2) = 5;
  70. mat2(0,0) = 0;
  71. mat2(0,1) = 1;
  72. mat2(0,2) = 2;
  73. mat2(1,0) = 3;
  74. mat2(1,1) = 4;
  75. mat2(1,2) = 5;
  76. int col_major[6];
  77. int row_major[6];
  78. memset(col_major, 0, 6*sizeof(int));
  79. memset(row_major, 0, 6*sizeof(int));
  80. TensorMap<Tensor<int, 2> > mat3(row_major, 2, 3);
  81. TensorMap<Tensor<int, 2, RowMajor> > mat4(col_major, 2, 3);
  82. mat3 = mat1;
  83. mat4 = mat2;
  84. VERIFY_IS_EQUAL(mat3(0,0), 0);
  85. VERIFY_IS_EQUAL(mat3(0,1), 1);
  86. VERIFY_IS_EQUAL(mat3(0,2), 2);
  87. VERIFY_IS_EQUAL(mat3(1,0), 3);
  88. VERIFY_IS_EQUAL(mat3(1,1), 4);
  89. VERIFY_IS_EQUAL(mat3(1,2), 5);
  90. VERIFY_IS_EQUAL(mat4(0,0), 0);
  91. VERIFY_IS_EQUAL(mat4(0,1), 1);
  92. VERIFY_IS_EQUAL(mat4(0,2), 2);
  93. VERIFY_IS_EQUAL(mat4(1,0), 3);
  94. VERIFY_IS_EQUAL(mat4(1,1), 4);
  95. VERIFY_IS_EQUAL(mat4(1,2), 5);
  96. mat1.setZero();
  97. mat2.setZero();
  98. mat1 = mat3;
  99. mat2 = mat4;
  100. VERIFY_IS_EQUAL(mat1(0,0), 0);
  101. VERIFY_IS_EQUAL(mat1(0,1), 1);
  102. VERIFY_IS_EQUAL(mat1(0,2), 2);
  103. VERIFY_IS_EQUAL(mat1(1,0), 3);
  104. VERIFY_IS_EQUAL(mat1(1,1), 4);
  105. VERIFY_IS_EQUAL(mat1(1,2), 5);
  106. VERIFY_IS_EQUAL(mat2(0,0), 0);
  107. VERIFY_IS_EQUAL(mat2(0,1), 1);
  108. VERIFY_IS_EQUAL(mat2(0,2), 2);
  109. VERIFY_IS_EQUAL(mat2(1,0), 3);
  110. VERIFY_IS_EQUAL(mat2(1,1), 4);
  111. VERIFY_IS_EQUAL(mat2(1,2), 5);
  112. }
  113. static void test_3d()
  114. {
  115. Tensor<int, 3> mat1(2,3,7);
  116. Tensor<int, 3, RowMajor> mat2(2,3,7);
  117. int val = 0;
  118. for (int i = 0; i < 2; ++i) {
  119. for (int j = 0; j < 3; ++j) {
  120. for (int k = 0; k < 7; ++k) {
  121. mat1(i,j,k) = val;
  122. mat2(i,j,k) = val;
  123. val++;
  124. }
  125. }
  126. }
  127. int col_major[2*3*7];
  128. int row_major[2*3*7];
  129. memset(col_major, 0, 2*3*7*sizeof(int));
  130. memset(row_major, 0, 2*3*7*sizeof(int));
  131. TensorMap<Tensor<int, 3> > mat3(col_major, 2, 3, 7);
  132. TensorMap<Tensor<int, 3, RowMajor> > mat4(row_major, 2, 3, 7);
  133. mat3 = mat1;
  134. mat4 = mat2;
  135. val = 0;
  136. for (int i = 0; i < 2; ++i) {
  137. for (int j = 0; j < 3; ++j) {
  138. for (int k = 0; k < 7; ++k) {
  139. VERIFY_IS_EQUAL(mat3(i,j,k), val);
  140. VERIFY_IS_EQUAL(mat4(i,j,k), val);
  141. val++;
  142. }
  143. }
  144. }
  145. mat1.setZero();
  146. mat2.setZero();
  147. mat1 = mat3;
  148. mat2 = mat4;
  149. val = 0;
  150. for (int i = 0; i < 2; ++i) {
  151. for (int j = 0; j < 3; ++j) {
  152. for (int k = 0; k < 7; ++k) {
  153. VERIFY_IS_EQUAL(mat1(i,j,k), val);
  154. VERIFY_IS_EQUAL(mat2(i,j,k), val);
  155. val++;
  156. }
  157. }
  158. }
  159. }
  160. static void test_same_type()
  161. {
  162. Tensor<int, 1> orig_tensor(5);
  163. Tensor<int, 1> dest_tensor(5);
  164. orig_tensor.setRandom();
  165. dest_tensor.setRandom();
  166. int* orig_data = orig_tensor.data();
  167. int* dest_data = dest_tensor.data();
  168. dest_tensor = orig_tensor;
  169. VERIFY_IS_EQUAL(orig_tensor.data(), orig_data);
  170. VERIFY_IS_EQUAL(dest_tensor.data(), dest_data);
  171. for (int i = 0; i < 5; ++i) {
  172. VERIFY_IS_EQUAL(dest_tensor(i), orig_tensor(i));
  173. }
  174. TensorFixedSize<int, Sizes<5> > orig_array;
  175. TensorFixedSize<int, Sizes<5> > dest_array;
  176. orig_array.setRandom();
  177. dest_array.setRandom();
  178. orig_data = orig_array.data();
  179. dest_data = dest_array.data();
  180. dest_array = orig_array;
  181. VERIFY_IS_EQUAL(orig_array.data(), orig_data);
  182. VERIFY_IS_EQUAL(dest_array.data(), dest_data);
  183. for (int i = 0; i < 5; ++i) {
  184. VERIFY_IS_EQUAL(dest_array(i), orig_array(i));
  185. }
  186. int orig[5] = {1, 2, 3, 4, 5};
  187. int dest[5] = {6, 7, 8, 9, 10};
  188. TensorMap<Tensor<int, 1> > orig_map(orig, 5);
  189. TensorMap<Tensor<int, 1> > dest_map(dest, 5);
  190. orig_data = orig_map.data();
  191. dest_data = dest_map.data();
  192. dest_map = orig_map;
  193. VERIFY_IS_EQUAL(orig_map.data(), orig_data);
  194. VERIFY_IS_EQUAL(dest_map.data(), dest_data);
  195. for (int i = 0; i < 5; ++i) {
  196. VERIFY_IS_EQUAL(dest[i], i+1);
  197. }
  198. }
  199. static void test_auto_resize()
  200. {
  201. Tensor<int, 1> tensor1;
  202. Tensor<int, 1> tensor2(3);
  203. Tensor<int, 1> tensor3(5);
  204. Tensor<int, 1> tensor4(7);
  205. Tensor<int, 1> new_tensor(5);
  206. new_tensor.setRandom();
  207. tensor1 = tensor2 = tensor3 = tensor4 = new_tensor;
  208. VERIFY_IS_EQUAL(tensor1.dimension(0), new_tensor.dimension(0));
  209. VERIFY_IS_EQUAL(tensor2.dimension(0), new_tensor.dimension(0));
  210. VERIFY_IS_EQUAL(tensor3.dimension(0), new_tensor.dimension(0));
  211. VERIFY_IS_EQUAL(tensor4.dimension(0), new_tensor.dimension(0));
  212. for (int i = 0; i < new_tensor.dimension(0); ++i) {
  213. VERIFY_IS_EQUAL(tensor1(i), new_tensor(i));
  214. VERIFY_IS_EQUAL(tensor2(i), new_tensor(i));
  215. VERIFY_IS_EQUAL(tensor3(i), new_tensor(i));
  216. VERIFY_IS_EQUAL(tensor4(i), new_tensor(i));
  217. }
  218. }
  219. static void test_compound_assign()
  220. {
  221. Tensor<int, 1> start_tensor(10);
  222. Tensor<int, 1> offset_tensor(10);
  223. start_tensor.setRandom();
  224. offset_tensor.setRandom();
  225. Tensor<int, 1> tensor = start_tensor;
  226. tensor += offset_tensor;
  227. for (int i = 0; i < 10; ++i) {
  228. VERIFY_IS_EQUAL(tensor(i), start_tensor(i) + offset_tensor(i));
  229. }
  230. tensor = start_tensor;
  231. tensor -= offset_tensor;
  232. for (int i = 0; i < 10; ++i) {
  233. VERIFY_IS_EQUAL(tensor(i), start_tensor(i) - offset_tensor(i));
  234. }
  235. tensor = start_tensor;
  236. tensor *= offset_tensor;
  237. for (int i = 0; i < 10; ++i) {
  238. VERIFY_IS_EQUAL(tensor(i), start_tensor(i) * offset_tensor(i));
  239. }
  240. tensor = start_tensor;
  241. tensor /= offset_tensor;
  242. for (int i = 0; i < 10; ++i) {
  243. VERIFY_IS_EQUAL(tensor(i), start_tensor(i) / offset_tensor(i));
  244. }
  245. }
  // Exercises Tensor::setValues() with nested brace lists for ranks 1 to 3,
  // including shorter lists (which overwrite only the leading coefficients of
  // each dimension and leave the rest untouched) and assigning the result of
  // setValues() to another tensor.
  // NOTE(review): the body is compiled only when EIGEN_HAS_VARIADIC_TEMPLATES
  // evaluates to non-zero. If the macro is undefined, `#if` treats it as 0 and
  // this test silently becomes empty -- confirm the macro still exists in the
  // Eigen version this test targets.
  static void test_std_initializers_tensor() {
  #if EIGEN_HAS_VARIADIC_TEMPLATES
    // Rank 1, full list.
    Tensor<int, 1> a(3);
    a.setValues({0, 1, 2});
    for (int i = 0; i < 3; ++i) {
      VERIFY_IS_EQUAL(a(i), i);
    }

    // Rank 1, shorter list: only the first two coefficients change.
    a.setValues({10, 20});
    VERIFY_IS_EQUAL(a(0), 10);
    VERIFY_IS_EQUAL(a(1), 20);
    VERIFY_IS_EQUAL(a(2), 2);

    // The result of setValues() can be assigned to another tensor.
    Tensor<int, 1> a2(3);
    a2 = a.setValues({100, 200, 300});
    for (int i = 0; i < 3; ++i) {
      VERIFY_IS_EQUAL(a(i), 100 * (i + 1));
    }
    for (int i = 0; i < 3; ++i) {
      VERIFY_IS_EQUAL(a2(i), 100 * (i + 1));
    }

    // Rank 2, full list: b(i, j) == 3*i + j.
    Tensor<int, 2> b(2, 3);
    b.setValues({{0, 1, 2}, {3, 4, 5}});
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 3; ++j) {
        VERIFY_IS_EQUAL(b(i, j), 3 * i + j);
      }
    }

    // Rank 2, ragged shorter list: only the listed leading entries change.
    b.setValues({{10, 20}, {30}});
    VERIFY_IS_EQUAL(b(0, 0), 10);
    VERIFY_IS_EQUAL(b(0, 1), 20);
    VERIFY_IS_EQUAL(b(0, 2), 2);
    VERIFY_IS_EQUAL(b(1, 0), 30);
    VERIFY_IS_EQUAL(b(1, 1), 4);
    VERIFY_IS_EQUAL(b(1, 2), 5);

    // Rank 3, full list: c(i, j, k) == 10*i + 4*j + k.
    Eigen::Tensor<int, 3> c(3, 2, 4);
    c.setValues({{{0, 1, 2, 3}, {4, 5, 6, 7}},
                 {{10, 11, 12, 13}, {14, 15, 16, 17}},
                 {{20, 21, 22, 23}, {24, 25, 26, 27}}});
    for (int i = 0; i < 3; ++i) {
      for (int j = 0; j < 2; ++j) {
        for (int k = 0; k < 4; ++k) {
          VERIFY_IS_EQUAL(c(i, j, k), 10 * i + 4 * j + k);
        }
      }
    }
  #endif  // EIGEN_HAS_VARIADIC_TEMPLATES
  }
  313. EIGEN_DECLARE_TEST(cxx11_tensor_assign)
  314. {
  315. CALL_SUBTEST(test_1d());
  316. CALL_SUBTEST(test_2d());
  317. CALL_SUBTEST(test_3d());
  318. CALL_SUBTEST(test_same_type());
  319. CALL_SUBTEST(test_auto_resize());
  320. CALL_SUBTEST(test_compound_assign());
  321. CALL_SUBTEST(test_std_initializers_tensor());
  322. }