cxx11_tensor_chipping.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. template<int DataLayout>
  13. static void test_simple_chip()
  14. {
  15. Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
  16. tensor.setRandom();
  17. Tensor<float, 4, DataLayout> chip1;
  18. chip1 = tensor.template chip<0>(1);
  19. VERIFY_IS_EQUAL(chip1.dimension(0), 3);
  20. VERIFY_IS_EQUAL(chip1.dimension(1), 5);
  21. VERIFY_IS_EQUAL(chip1.dimension(2), 7);
  22. VERIFY_IS_EQUAL(chip1.dimension(3), 11);
  23. for (int i = 0; i < 3; ++i) {
  24. for (int j = 0; j < 5; ++j) {
  25. for (int k = 0; k < 7; ++k) {
  26. for (int l = 0; l < 11; ++l) {
  27. VERIFY_IS_EQUAL(chip1(i,j,k,l), tensor(1,i,j,k,l));
  28. }
  29. }
  30. }
  31. }
  32. Tensor<float, 4, DataLayout> chip2 = tensor.template chip<1>(1);
  33. VERIFY_IS_EQUAL(chip2.dimension(0), 2);
  34. VERIFY_IS_EQUAL(chip2.dimension(1), 5);
  35. VERIFY_IS_EQUAL(chip2.dimension(2), 7);
  36. VERIFY_IS_EQUAL(chip2.dimension(3), 11);
  37. for (int i = 0; i < 2; ++i) {
  38. for (int j = 0; j < 5; ++j) {
  39. for (int k = 0; k < 7; ++k) {
  40. for (int l = 0; l < 11; ++l) {
  41. VERIFY_IS_EQUAL(chip2(i,j,k,l), tensor(i,1,j,k,l));
  42. }
  43. }
  44. }
  45. }
  46. Tensor<float, 4, DataLayout> chip3 = tensor.template chip<2>(2);
  47. VERIFY_IS_EQUAL(chip3.dimension(0), 2);
  48. VERIFY_IS_EQUAL(chip3.dimension(1), 3);
  49. VERIFY_IS_EQUAL(chip3.dimension(2), 7);
  50. VERIFY_IS_EQUAL(chip3.dimension(3), 11);
  51. for (int i = 0; i < 2; ++i) {
  52. for (int j = 0; j < 3; ++j) {
  53. for (int k = 0; k < 7; ++k) {
  54. for (int l = 0; l < 11; ++l) {
  55. VERIFY_IS_EQUAL(chip3(i,j,k,l), tensor(i,j,2,k,l));
  56. }
  57. }
  58. }
  59. }
  60. Tensor<float, 4, DataLayout> chip4(tensor.template chip<3>(5));
  61. VERIFY_IS_EQUAL(chip4.dimension(0), 2);
  62. VERIFY_IS_EQUAL(chip4.dimension(1), 3);
  63. VERIFY_IS_EQUAL(chip4.dimension(2), 5);
  64. VERIFY_IS_EQUAL(chip4.dimension(3), 11);
  65. for (int i = 0; i < 2; ++i) {
  66. for (int j = 0; j < 3; ++j) {
  67. for (int k = 0; k < 5; ++k) {
  68. for (int l = 0; l < 11; ++l) {
  69. VERIFY_IS_EQUAL(chip4(i,j,k,l), tensor(i,j,k,5,l));
  70. }
  71. }
  72. }
  73. }
  74. Tensor<float, 4, DataLayout> chip5(tensor.template chip<4>(7));
  75. VERIFY_IS_EQUAL(chip5.dimension(0), 2);
  76. VERIFY_IS_EQUAL(chip5.dimension(1), 3);
  77. VERIFY_IS_EQUAL(chip5.dimension(2), 5);
  78. VERIFY_IS_EQUAL(chip5.dimension(3), 7);
  79. for (int i = 0; i < 2; ++i) {
  80. for (int j = 0; j < 3; ++j) {
  81. for (int k = 0; k < 5; ++k) {
  82. for (int l = 0; l < 7; ++l) {
  83. VERIFY_IS_EQUAL(chip5(i,j,k,l), tensor(i,j,k,l,7));
  84. }
  85. }
  86. }
  87. }
  88. }
  89. template<int DataLayout>
  90. static void test_dynamic_chip()
  91. {
  92. Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
  93. tensor.setRandom();
  94. Tensor<float, 4, DataLayout> chip1;
  95. chip1 = tensor.chip(1, 0);
  96. VERIFY_IS_EQUAL(chip1.dimension(0), 3);
  97. VERIFY_IS_EQUAL(chip1.dimension(1), 5);
  98. VERIFY_IS_EQUAL(chip1.dimension(2), 7);
  99. VERIFY_IS_EQUAL(chip1.dimension(3), 11);
  100. for (int i = 0; i < 3; ++i) {
  101. for (int j = 0; j < 5; ++j) {
  102. for (int k = 0; k < 7; ++k) {
  103. for (int l = 0; l < 11; ++l) {
  104. VERIFY_IS_EQUAL(chip1(i,j,k,l), tensor(1,i,j,k,l));
  105. }
  106. }
  107. }
  108. }
  109. Tensor<float, 4, DataLayout> chip2 = tensor.chip(1, 1);
  110. VERIFY_IS_EQUAL(chip2.dimension(0), 2);
  111. VERIFY_IS_EQUAL(chip2.dimension(1), 5);
  112. VERIFY_IS_EQUAL(chip2.dimension(2), 7);
  113. VERIFY_IS_EQUAL(chip2.dimension(3), 11);
  114. for (int i = 0; i < 2; ++i) {
  115. for (int j = 0; j < 5; ++j) {
  116. for (int k = 0; k < 7; ++k) {
  117. for (int l = 0; l < 11; ++l) {
  118. VERIFY_IS_EQUAL(chip2(i,j,k,l), tensor(i,1,j,k,l));
  119. }
  120. }
  121. }
  122. }
  123. Tensor<float, 4, DataLayout> chip3 = tensor.chip(2, 2);
  124. VERIFY_IS_EQUAL(chip3.dimension(0), 2);
  125. VERIFY_IS_EQUAL(chip3.dimension(1), 3);
  126. VERIFY_IS_EQUAL(chip3.dimension(2), 7);
  127. VERIFY_IS_EQUAL(chip3.dimension(3), 11);
  128. for (int i = 0; i < 2; ++i) {
  129. for (int j = 0; j < 3; ++j) {
  130. for (int k = 0; k < 7; ++k) {
  131. for (int l = 0; l < 11; ++l) {
  132. VERIFY_IS_EQUAL(chip3(i,j,k,l), tensor(i,j,2,k,l));
  133. }
  134. }
  135. }
  136. }
  137. Tensor<float, 4, DataLayout> chip4(tensor.chip(5, 3));
  138. VERIFY_IS_EQUAL(chip4.dimension(0), 2);
  139. VERIFY_IS_EQUAL(chip4.dimension(1), 3);
  140. VERIFY_IS_EQUAL(chip4.dimension(2), 5);
  141. VERIFY_IS_EQUAL(chip4.dimension(3), 11);
  142. for (int i = 0; i < 2; ++i) {
  143. for (int j = 0; j < 3; ++j) {
  144. for (int k = 0; k < 5; ++k) {
  145. for (int l = 0; l < 11; ++l) {
  146. VERIFY_IS_EQUAL(chip4(i,j,k,l), tensor(i,j,k,5,l));
  147. }
  148. }
  149. }
  150. }
  151. Tensor<float, 4, DataLayout> chip5(tensor.chip(7, 4));
  152. VERIFY_IS_EQUAL(chip5.dimension(0), 2);
  153. VERIFY_IS_EQUAL(chip5.dimension(1), 3);
  154. VERIFY_IS_EQUAL(chip5.dimension(2), 5);
  155. VERIFY_IS_EQUAL(chip5.dimension(3), 7);
  156. for (int i = 0; i < 2; ++i) {
  157. for (int j = 0; j < 3; ++j) {
  158. for (int k = 0; k < 5; ++k) {
  159. for (int l = 0; l < 7; ++l) {
  160. VERIFY_IS_EQUAL(chip5(i,j,k,l), tensor(i,j,k,l,7));
  161. }
  162. }
  163. }
  164. }
  165. }
  166. template<int DataLayout>
  167. static void test_chip_in_expr() {
  168. Tensor<float, 5, DataLayout> input1(2,3,5,7,11);
  169. input1.setRandom();
  170. Tensor<float, 4, DataLayout> input2(3,5,7,11);
  171. input2.setRandom();
  172. Tensor<float, 4, DataLayout> result = input1.template chip<0>(0) + input2;
  173. for (int i = 0; i < 3; ++i) {
  174. for (int j = 0; j < 5; ++j) {
  175. for (int k = 0; k < 7; ++k) {
  176. for (int l = 0; l < 11; ++l) {
  177. float expected = input1(0,i,j,k,l) + input2(i,j,k,l);
  178. VERIFY_IS_EQUAL(result(i,j,k,l), expected);
  179. }
  180. }
  181. }
  182. }
  183. Tensor<float, 3, DataLayout> input3(3,7,11);
  184. input3.setRandom();
  185. Tensor<float, 3, DataLayout> result2 = input1.template chip<0>(0).template chip<1>(2) + input3;
  186. for (int i = 0; i < 3; ++i) {
  187. for (int j = 0; j < 7; ++j) {
  188. for (int k = 0; k < 11; ++k) {
  189. float expected = input1(0,i,2,j,k) + input3(i,j,k);
  190. VERIFY_IS_EQUAL(result2(i,j,k), expected);
  191. }
  192. }
  193. }
  194. }
  195. template<int DataLayout>
  196. static void test_chip_as_lvalue()
  197. {
  198. Tensor<float, 5, DataLayout> input1(2,3,5,7,11);
  199. input1.setRandom();
  200. Tensor<float, 4, DataLayout> input2(3,5,7,11);
  201. input2.setRandom();
  202. Tensor<float, 5, DataLayout> tensor = input1;
  203. tensor.template chip<0>(1) = input2;
  204. for (int i = 0; i < 2; ++i) {
  205. for (int j = 0; j < 3; ++j) {
  206. for (int k = 0; k < 5; ++k) {
  207. for (int l = 0; l < 7; ++l) {
  208. for (int m = 0; m < 11; ++m) {
  209. if (i != 1) {
  210. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  211. } else {
  212. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input2(j,k,l,m));
  213. }
  214. }
  215. }
  216. }
  217. }
  218. }
  219. Tensor<float, 4, DataLayout> input3(2,5,7,11);
  220. input3.setRandom();
  221. tensor = input1;
  222. tensor.template chip<1>(1) = input3;
  223. for (int i = 0; i < 2; ++i) {
  224. for (int j = 0; j < 3; ++j) {
  225. for (int k = 0; k < 5; ++k) {
  226. for (int l = 0; l < 7; ++l) {
  227. for (int m = 0; m < 11; ++m) {
  228. if (j != 1) {
  229. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  230. } else {
  231. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input3(i,k,l,m));
  232. }
  233. }
  234. }
  235. }
  236. }
  237. }
  238. Tensor<float, 4, DataLayout> input4(2,3,7,11);
  239. input4.setRandom();
  240. tensor = input1;
  241. tensor.template chip<2>(3) = input4;
  242. for (int i = 0; i < 2; ++i) {
  243. for (int j = 0; j < 3; ++j) {
  244. for (int k = 0; k < 5; ++k) {
  245. for (int l = 0; l < 7; ++l) {
  246. for (int m = 0; m < 11; ++m) {
  247. if (k != 3) {
  248. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  249. } else {
  250. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input4(i,j,l,m));
  251. }
  252. }
  253. }
  254. }
  255. }
  256. }
  257. Tensor<float, 4, DataLayout> input5(2,3,5,11);
  258. input5.setRandom();
  259. tensor = input1;
  260. tensor.template chip<3>(4) = input5;
  261. for (int i = 0; i < 2; ++i) {
  262. for (int j = 0; j < 3; ++j) {
  263. for (int k = 0; k < 5; ++k) {
  264. for (int l = 0; l < 7; ++l) {
  265. for (int m = 0; m < 11; ++m) {
  266. if (l != 4) {
  267. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  268. } else {
  269. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input5(i,j,k,m));
  270. }
  271. }
  272. }
  273. }
  274. }
  275. }
  276. Tensor<float, 4, DataLayout> input6(2,3,5,7);
  277. input6.setRandom();
  278. tensor = input1;
  279. tensor.template chip<4>(5) = input6;
  280. for (int i = 0; i < 2; ++i) {
  281. for (int j = 0; j < 3; ++j) {
  282. for (int k = 0; k < 5; ++k) {
  283. for (int l = 0; l < 7; ++l) {
  284. for (int m = 0; m < 11; ++m) {
  285. if (m != 5) {
  286. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  287. } else {
  288. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input6(i,j,k,l));
  289. }
  290. }
  291. }
  292. }
  293. }
  294. }
  295. Tensor<float, 5, DataLayout> input7(2,3,5,7,11);
  296. input7.setRandom();
  297. tensor = input1;
  298. tensor.chip(0, 0) = input7.chip(0, 0);
  299. for (int i = 0; i < 2; ++i) {
  300. for (int j = 0; j < 3; ++j) {
  301. for (int k = 0; k < 5; ++k) {
  302. for (int l = 0; l < 7; ++l) {
  303. for (int m = 0; m < 11; ++m) {
  304. if (i != 0) {
  305. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
  306. } else {
  307. VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input7(i,j,k,l,m));
  308. }
  309. }
  310. }
  311. }
  312. }
  313. }
  314. }
  315. static void test_chip_raw_data_col_major()
  316. {
  317. Tensor<float, 5, ColMajor> tensor(2,3,5,7,11);
  318. tensor.setRandom();
  319. typedef TensorEvaluator<decltype(tensor.chip<4>(3)), DefaultDevice> Evaluator4;
  320. auto chip = Evaluator4(tensor.chip<4>(3), DefaultDevice());
  321. for (int i = 0; i < 2; ++i) {
  322. for (int j = 0; j < 3; ++j) {
  323. for (int k = 0; k < 5; ++k) {
  324. for (int l = 0; l < 7; ++l) {
  325. int chip_index = i + 2 * (j + 3 * (k + 5 * l));
  326. VERIFY_IS_EQUAL(chip.data()[chip_index], tensor(i,j,k,l,3));
  327. }
  328. }
  329. }
  330. }
  331. typedef TensorEvaluator<decltype(tensor.chip<0>(0)), DefaultDevice> Evaluator0;
  332. auto chip0 = Evaluator0(tensor.chip<0>(0), DefaultDevice());
  333. VERIFY_IS_EQUAL(chip0.data(), static_cast<float*>(0));
  334. typedef TensorEvaluator<decltype(tensor.chip<1>(0)), DefaultDevice> Evaluator1;
  335. auto chip1 = Evaluator1(tensor.chip<1>(0), DefaultDevice());
  336. VERIFY_IS_EQUAL(chip1.data(), static_cast<float*>(0));
  337. typedef TensorEvaluator<decltype(tensor.chip<2>(0)), DefaultDevice> Evaluator2;
  338. auto chip2 = Evaluator2(tensor.chip<2>(0), DefaultDevice());
  339. VERIFY_IS_EQUAL(chip2.data(), static_cast<float*>(0));
  340. typedef TensorEvaluator<decltype(tensor.chip<3>(0)), DefaultDevice> Evaluator3;
  341. auto chip3 = Evaluator3(tensor.chip<3>(0), DefaultDevice());
  342. VERIFY_IS_EQUAL(chip3.data(), static_cast<float*>(0));
  343. }
  344. static void test_chip_raw_data_row_major()
  345. {
  346. Tensor<float, 5, RowMajor> tensor(11,7,5,3,2);
  347. tensor.setRandom();
  348. typedef TensorEvaluator<decltype(tensor.chip<0>(3)), DefaultDevice> Evaluator0;
  349. auto chip = Evaluator0(tensor.chip<0>(3), DefaultDevice());
  350. for (int i = 0; i < 7; ++i) {
  351. for (int j = 0; j < 5; ++j) {
  352. for (int k = 0; k < 3; ++k) {
  353. for (int l = 0; l < 2; ++l) {
  354. int chip_index = l + 2 * (k + 3 * (j + 5 * i));
  355. VERIFY_IS_EQUAL(chip.data()[chip_index], tensor(3,i,j,k,l));
  356. }
  357. }
  358. }
  359. }
  360. typedef TensorEvaluator<decltype(tensor.chip<1>(0)), DefaultDevice> Evaluator1;
  361. auto chip1 = Evaluator1(tensor.chip<1>(0), DefaultDevice());
  362. VERIFY_IS_EQUAL(chip1.data(), static_cast<float*>(0));
  363. typedef TensorEvaluator<decltype(tensor.chip<2>(0)), DefaultDevice> Evaluator2;
  364. auto chip2 = Evaluator2(tensor.chip<2>(0), DefaultDevice());
  365. VERIFY_IS_EQUAL(chip2.data(), static_cast<float*>(0));
  366. typedef TensorEvaluator<decltype(tensor.chip<3>(0)), DefaultDevice> Evaluator3;
  367. auto chip3 = Evaluator3(tensor.chip<3>(0), DefaultDevice());
  368. VERIFY_IS_EQUAL(chip3.data(), static_cast<float*>(0));
  369. typedef TensorEvaluator<decltype(tensor.chip<4>(0)), DefaultDevice> Evaluator4;
  370. auto chip4 = Evaluator4(tensor.chip<4>(0), DefaultDevice());
  371. VERIFY_IS_EQUAL(chip4.data(), static_cast<float*>(0));
  372. }
// Test entry point for the tensor chipping tests.
// Each layout-templated test runs for ColMajor and then RowMajor, followed by
// the two raw-data tests, which fix their own layout. Subtests run in the
// order listed.
// NOTE(review): the subtests fill their inputs with setRandom(); if that draws
// from a shared generator, reordering these calls changes the data each
// subtest sees — keep the order stable unless that is confirmed harmless.
EIGEN_DECLARE_TEST(cxx11_tensor_chipping)
{
  CALL_SUBTEST(test_simple_chip<ColMajor>());
  CALL_SUBTEST(test_simple_chip<RowMajor>());
  CALL_SUBTEST(test_dynamic_chip<ColMajor>());
  CALL_SUBTEST(test_dynamic_chip<RowMajor>());
  CALL_SUBTEST(test_chip_in_expr<ColMajor>());
  CALL_SUBTEST(test_chip_in_expr<RowMajor>());
  CALL_SUBTEST(test_chip_as_lvalue<ColMajor>());
  CALL_SUBTEST(test_chip_as_lvalue<RowMajor>());
  CALL_SUBTEST(test_chip_raw_data_col_major());
  CALL_SUBTEST(test_chip_raw_data_row_major());
}