// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

#include <Eigen/CXX11/Tensor>

using Eigen::Tensor;
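
// Broadcast factors of 1 must leave the tensor unchanged; non-trivial
// factors tile it, so entry (i,j,k,l) of the result should equal
// tensor(i % 2, j % 3, k % 5, l % 7).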
template <int DataLayout>
static void test_simple_broadcasting()
{
  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 1;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 1;

  Tensor<float, 4, DataLayout> no_broadcast;
  no_broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
  VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
  VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
        }
      }
    }
  }

  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 1;
  broadcasts[3] = 4;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 28);

  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 28; ++l) {
          VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
        }
      }
    }
  }
}
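
// Same check with a leading dimension of 8 and then 11; the two sizes
// presumably exercise both the packet (vectorized) and the scalar
// broadcasting paths.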
template <int DataLayout>
static void test_vectorized_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;
  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

#if EIGEN_HAS_VARIADIC_TEMPLATES
  tensor.resize(11,3,5);
#else
  array<Index, 3> new_dims;
  new_dims[0] = 11;
  new_dims[1] = 3;
  new_dims[2] = 5;
  tensor.resize(new_dims);
#endif

  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}
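
// Broadcast factors known at compile time, expressed as an Eigen::IndexList
// of type2index values when EIGEN_HAS_INDEX_LIST is defined, with a plain
// runtime array as the fallback.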
template <int DataLayout>
static void test_static_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();

#if defined(EIGEN_HAS_INDEX_LIST)
  Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
#else
  Eigen::array<int, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;
#endif

  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

#if EIGEN_HAS_VARIADIC_TEMPLATES
  tensor.resize(11,3,5);
#else
  array<Index, 3> new_dims;
  new_dims[0] = 11;
  new_dims[1] = 3;
  new_dims[2] = 5;
  tensor.resize(new_dims);
#endif

  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}
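
// Broadcasting a TensorFixedSize expression; the body is compiled out with
// "#if 0" until the missing operator noted below is available.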
template <int DataLayout>
static void test_fixed_size_broadcasting()
{
  // Need to add a [] operator to the Size class for this to work
#if 0
  Tensor<float, 1, DataLayout> t1(10);
  t1.setRandom();
  TensorFixedSize<float, Sizes<1>, DataLayout> t2;
  t2 = t2.constant(20.0f);

  Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
  }

  TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
  Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
  }
#endif
}
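
// Broadcasting along a leading dimension of size 1 (1x13x5x7 -> 9x13x5x7).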
template <int DataLayout>
static void test_simple_broadcasting_one_by_n()
{
  Tensor<float, 4, DataLayout> tensor(1,13,5,7);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 9;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 1;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 13);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 7);

  for (int i = 0; i < 9; ++i) {
    for (int j = 0; j < 13; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l));
        }
      }
    }
  }
}
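
// Broadcasting along a trailing dimension of size 1 (7x3x5x1 -> 7x3x5x19).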
template <int DataLayout>
static void test_simple_broadcasting_n_by_one()
{
  Tensor<float, 4, DataLayout> tensor(7,3,5,1);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 1;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 19;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 7);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 3);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);

  for (int i = 0; i < 7; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 19; ++l) {
          VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l));
        }
      }
    }
  }
}
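
// Broadcasting a 1x7x1 tensor on both unit dimensions (-> 5x7x13).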
template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_1d()
{
  Tensor<float, 3, DataLayout> tensor(1,7,1);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 13;
  Tensor<float, 3, DataLayout> broadcasted;
  broadcasted = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
      }
    }
  }
}
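
// Same pattern with two middle dimensions: 1x7x13x1 broadcast to 5x7x13x19.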
template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_2d()
{
  Tensor<float, 4, DataLayout> tensor(1,7,13,1);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 19;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        for (int l = 0; l < 19; ++l) {
          VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
        }
      }
    }
  }
}
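
// Run every subtest in both column-major and row-major layouts.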
EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting)
{
  CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
  CALL_SUBTEST(test_static_broadcasting<ColMajor>());
  CALL_SUBTEST(test_static_broadcasting<RowMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
}