cxx11_tensor_shuffling.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::array;
  13. template <int DataLayout>
  14. static void test_simple_shuffling()
  15. {
  16. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  17. tensor.setRandom();
  18. array<ptrdiff_t, 4> shuffles;
  19. shuffles[0] = 0;
  20. shuffles[1] = 1;
  21. shuffles[2] = 2;
  22. shuffles[3] = 3;
  23. Tensor<float, 4, DataLayout> no_shuffle;
  24. no_shuffle = tensor.shuffle(shuffles);
  25. VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
  26. VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
  27. VERIFY_IS_EQUAL(no_shuffle.dimension(2), 5);
  28. VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
  29. for (int i = 0; i < 2; ++i) {
  30. for (int j = 0; j < 3; ++j) {
  31. for (int k = 0; k < 5; ++k) {
  32. for (int l = 0; l < 7; ++l) {
  33. VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
  34. }
  35. }
  36. }
  37. }
  38. shuffles[0] = 2;
  39. shuffles[1] = 3;
  40. shuffles[2] = 1;
  41. shuffles[3] = 0;
  42. Tensor<float, 4, DataLayout> shuffle;
  43. shuffle = tensor.shuffle(shuffles);
  44. VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
  45. VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
  46. VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
  47. VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
  48. for (int i = 0; i < 2; ++i) {
  49. for (int j = 0; j < 3; ++j) {
  50. for (int k = 0; k < 5; ++k) {
  51. for (int l = 0; l < 7; ++l) {
  52. VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
  53. }
  54. }
  55. }
  56. }
  57. }
  58. template <int DataLayout>
  59. static void test_expr_shuffling()
  60. {
  61. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  62. tensor.setRandom();
  63. array<ptrdiff_t, 4> shuffles;
  64. shuffles[0] = 2;
  65. shuffles[1] = 3;
  66. shuffles[2] = 1;
  67. shuffles[3] = 0;
  68. Tensor<float, 4, DataLayout> expected;
  69. expected = tensor.shuffle(shuffles);
  70. Tensor<float, 4, DataLayout> result(5, 7, 3, 2);
  71. array<ptrdiff_t, 4> src_slice_dim{{2, 3, 1, 7}};
  72. array<ptrdiff_t, 4> src_slice_start{{0, 0, 0, 0}};
  73. array<ptrdiff_t, 4> dst_slice_dim{{1, 7, 3, 2}};
  74. array<ptrdiff_t, 4> dst_slice_start{{0, 0, 0, 0}};
  75. for (int i = 0; i < 5; ++i) {
  76. result.slice(dst_slice_start, dst_slice_dim) =
  77. tensor.slice(src_slice_start, src_slice_dim).shuffle(shuffles);
  78. src_slice_start[2] += 1;
  79. dst_slice_start[0] += 1;
  80. }
  81. VERIFY_IS_EQUAL(result.dimension(0), 5);
  82. VERIFY_IS_EQUAL(result.dimension(1), 7);
  83. VERIFY_IS_EQUAL(result.dimension(2), 3);
  84. VERIFY_IS_EQUAL(result.dimension(3), 2);
  85. for (int i = 0; i < expected.dimension(0); ++i) {
  86. for (int j = 0; j < expected.dimension(1); ++j) {
  87. for (int k = 0; k < expected.dimension(2); ++k) {
  88. for (int l = 0; l < expected.dimension(3); ++l) {
  89. VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
  90. }
  91. }
  92. }
  93. }
  94. dst_slice_start[0] = 0;
  95. result.setRandom();
  96. for (int i = 0; i < 5; ++i) {
  97. result.slice(dst_slice_start, dst_slice_dim) =
  98. tensor.shuffle(shuffles).slice(dst_slice_start, dst_slice_dim);
  99. dst_slice_start[0] += 1;
  100. }
  101. for (int i = 0; i < expected.dimension(0); ++i) {
  102. for (int j = 0; j < expected.dimension(1); ++j) {
  103. for (int k = 0; k < expected.dimension(2); ++k) {
  104. for (int l = 0; l < expected.dimension(3); ++l) {
  105. VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
  106. }
  107. }
  108. }
  109. }
  110. }
  111. template <int DataLayout>
  112. static void test_shuffling_as_value()
  113. {
  114. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  115. tensor.setRandom();
  116. array<ptrdiff_t, 4> shuffles;
  117. shuffles[2] = 0;
  118. shuffles[3] = 1;
  119. shuffles[1] = 2;
  120. shuffles[0] = 3;
  121. Tensor<float, 4, DataLayout> shuffle(5,7,3,2);
  122. shuffle.shuffle(shuffles) = tensor;
  123. VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
  124. VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
  125. VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
  126. VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
  127. for (int i = 0; i < 2; ++i) {
  128. for (int j = 0; j < 3; ++j) {
  129. for (int k = 0; k < 5; ++k) {
  130. for (int l = 0; l < 7; ++l) {
  131. VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
  132. }
  133. }
  134. }
  135. }
  136. array<ptrdiff_t, 4> no_shuffle;
  137. no_shuffle[0] = 0;
  138. no_shuffle[1] = 1;
  139. no_shuffle[2] = 2;
  140. no_shuffle[3] = 3;
  141. Tensor<float, 4, DataLayout> shuffle2(5,7,3,2);
  142. shuffle2.shuffle(shuffles) = tensor.shuffle(no_shuffle);
  143. for (int i = 0; i < 5; ++i) {
  144. for (int j = 0; j < 7; ++j) {
  145. for (int k = 0; k < 3; ++k) {
  146. for (int l = 0; l < 2; ++l) {
  147. VERIFY_IS_EQUAL(shuffle2(i,j,k,l), shuffle(i,j,k,l));
  148. }
  149. }
  150. }
  151. }
  152. }
  153. template <int DataLayout>
  154. static void test_shuffle_unshuffle()
  155. {
  156. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  157. tensor.setRandom();
  158. // Choose a random permutation.
  159. array<ptrdiff_t, 4> shuffles;
  160. for (int i = 0; i < 4; ++i) {
  161. shuffles[i] = i;
  162. }
  163. array<ptrdiff_t, 4> shuffles_inverse;
  164. for (int i = 0; i < 4; ++i) {
  165. const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
  166. shuffles_inverse[shuffles[index]] = i;
  167. std::swap(shuffles[i], shuffles[index]);
  168. }
  169. Tensor<float, 4, DataLayout> shuffle;
  170. shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
  171. VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
  172. VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
  173. VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
  174. VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
  175. for (int i = 0; i < 2; ++i) {
  176. for (int j = 0; j < 3; ++j) {
  177. for (int k = 0; k < 5; ++k) {
  178. for (int l = 0; l < 7; ++l) {
  179. VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
  180. }
  181. }
  182. }
  183. }
  184. }
  185. template <int DataLayout>
  186. static void test_empty_shuffling()
  187. {
  188. Tensor<float, 4, DataLayout> tensor(2,3,0,7);
  189. tensor.setRandom();
  190. array<ptrdiff_t, 4> shuffles;
  191. shuffles[0] = 0;
  192. shuffles[1] = 1;
  193. shuffles[2] = 2;
  194. shuffles[3] = 3;
  195. Tensor<float, 4, DataLayout> no_shuffle;
  196. no_shuffle = tensor.shuffle(shuffles);
  197. VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
  198. VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
  199. VERIFY_IS_EQUAL(no_shuffle.dimension(2), 0);
  200. VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
  201. for (int i = 0; i < 2; ++i) {
  202. for (int j = 0; j < 3; ++j) {
  203. for (int k = 0; k < 0; ++k) {
  204. for (int l = 0; l < 7; ++l) {
  205. VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
  206. }
  207. }
  208. }
  209. }
  210. shuffles[0] = 2;
  211. shuffles[1] = 3;
  212. shuffles[2] = 1;
  213. shuffles[3] = 0;
  214. Tensor<float, 4, DataLayout> shuffle;
  215. shuffle = tensor.shuffle(shuffles);
  216. VERIFY_IS_EQUAL(shuffle.dimension(0), 0);
  217. VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
  218. VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
  219. VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
  220. for (int i = 0; i < 2; ++i) {
  221. for (int j = 0; j < 3; ++j) {
  222. for (int k = 0; k < 0; ++k) {
  223. for (int l = 0; l < 7; ++l) {
  224. VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
  225. }
  226. }
  227. }
  228. }
  229. }
// Test entry point: runs every shuffling scenario above, once for each
// storage order (ColMajor first, then RowMajor).
EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
{
  // Identity and fixed permutations of a dense tensor.
  CALL_SUBTEST(test_simple_shuffling<ColMajor>());
  CALL_SUBTEST(test_simple_shuffling<RowMajor>());
  // shuffle() combined with slice(), in both orders.
  CALL_SUBTEST(test_expr_shuffling<ColMajor>());
  CALL_SUBTEST(test_expr_shuffling<RowMajor>());
  // shuffle() used as an assignment target.
  CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
  CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
  // Random permutation followed by its inverse.
  CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
  CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
  // Tensor with a zero-sized dimension.
  CALL_SUBTEST(test_empty_shuffling<ColMajor>());
  CALL_SUBTEST(test_empty_shuffling<RowMajor>());
}