cxx11_tensor_map.cpp 7.8 KB


  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::RowMajor;
  13. static void test_0d()
  14. {
  15. Tensor<int, 0> scalar1;
  16. Tensor<int, 0, RowMajor> scalar2;
  17. TensorMap<const Tensor<int, 0> > scalar3(scalar1.data());
  18. TensorMap<const Tensor<int, 0, RowMajor> > scalar4(scalar2.data());
  19. scalar1() = 7;
  20. scalar2() = 13;
  21. VERIFY_IS_EQUAL(scalar1.rank(), 0);
  22. VERIFY_IS_EQUAL(scalar1.size(), 1);
  23. VERIFY_IS_EQUAL(scalar3(), 7);
  24. VERIFY_IS_EQUAL(scalar4(), 13);
  25. }
  26. static void test_1d()
  27. {
  28. Tensor<int, 1> vec1(6);
  29. Tensor<int, 1, RowMajor> vec2(6);
  30. TensorMap<const Tensor<int, 1> > vec3(vec1.data(), 6);
  31. TensorMap<const Tensor<int, 1, RowMajor> > vec4(vec2.data(), 6);
  32. vec1(0) = 4; vec2(0) = 0;
  33. vec1(1) = 8; vec2(1) = 1;
  34. vec1(2) = 15; vec2(2) = 2;
  35. vec1(3) = 16; vec2(3) = 3;
  36. vec1(4) = 23; vec2(4) = 4;
  37. vec1(5) = 42; vec2(5) = 5;
  38. VERIFY_IS_EQUAL(vec1.rank(), 1);
  39. VERIFY_IS_EQUAL(vec1.size(), 6);
  40. VERIFY_IS_EQUAL(vec1.dimension(0), 6);
  41. VERIFY_IS_EQUAL(vec3(0), 4);
  42. VERIFY_IS_EQUAL(vec3(1), 8);
  43. VERIFY_IS_EQUAL(vec3(2), 15);
  44. VERIFY_IS_EQUAL(vec3(3), 16);
  45. VERIFY_IS_EQUAL(vec3(4), 23);
  46. VERIFY_IS_EQUAL(vec3(5), 42);
  47. VERIFY_IS_EQUAL(vec4(0), 0);
  48. VERIFY_IS_EQUAL(vec4(1), 1);
  49. VERIFY_IS_EQUAL(vec4(2), 2);
  50. VERIFY_IS_EQUAL(vec4(3), 3);
  51. VERIFY_IS_EQUAL(vec4(4), 4);
  52. VERIFY_IS_EQUAL(vec4(5), 5);
  53. }
  54. static void test_2d()
  55. {
  56. Tensor<int, 2> mat1(2,3);
  57. Tensor<int, 2, RowMajor> mat2(2,3);
  58. mat1(0,0) = 0;
  59. mat1(0,1) = 1;
  60. mat1(0,2) = 2;
  61. mat1(1,0) = 3;
  62. mat1(1,1) = 4;
  63. mat1(1,2) = 5;
  64. mat2(0,0) = 0;
  65. mat2(0,1) = 1;
  66. mat2(0,2) = 2;
  67. mat2(1,0) = 3;
  68. mat2(1,1) = 4;
  69. mat2(1,2) = 5;
  70. TensorMap<const Tensor<int, 2> > mat3(mat1.data(), 2, 3);
  71. TensorMap<const Tensor<int, 2, RowMajor> > mat4(mat2.data(), 2, 3);
  72. VERIFY_IS_EQUAL(mat3.rank(), 2);
  73. VERIFY_IS_EQUAL(mat3.size(), 6);
  74. VERIFY_IS_EQUAL(mat3.dimension(0), 2);
  75. VERIFY_IS_EQUAL(mat3.dimension(1), 3);
  76. VERIFY_IS_EQUAL(mat4.rank(), 2);
  77. VERIFY_IS_EQUAL(mat4.size(), 6);
  78. VERIFY_IS_EQUAL(mat4.dimension(0), 2);
  79. VERIFY_IS_EQUAL(mat4.dimension(1), 3);
  80. VERIFY_IS_EQUAL(mat3(0,0), 0);
  81. VERIFY_IS_EQUAL(mat3(0,1), 1);
  82. VERIFY_IS_EQUAL(mat3(0,2), 2);
  83. VERIFY_IS_EQUAL(mat3(1,0), 3);
  84. VERIFY_IS_EQUAL(mat3(1,1), 4);
  85. VERIFY_IS_EQUAL(mat3(1,2), 5);
  86. VERIFY_IS_EQUAL(mat4(0,0), 0);
  87. VERIFY_IS_EQUAL(mat4(0,1), 1);
  88. VERIFY_IS_EQUAL(mat4(0,2), 2);
  89. VERIFY_IS_EQUAL(mat4(1,0), 3);
  90. VERIFY_IS_EQUAL(mat4(1,1), 4);
  91. VERIFY_IS_EQUAL(mat4(1,2), 5);
  92. }
  93. static void test_3d()
  94. {
  95. Tensor<int, 3> mat1(2,3,7);
  96. Tensor<int, 3, RowMajor> mat2(2,3,7);
  97. int val = 0;
  98. for (int i = 0; i < 2; ++i) {
  99. for (int j = 0; j < 3; ++j) {
  100. for (int k = 0; k < 7; ++k) {
  101. mat1(i,j,k) = val;
  102. mat2(i,j,k) = val;
  103. val++;
  104. }
  105. }
  106. }
  107. TensorMap<const Tensor<int, 3> > mat3(mat1.data(), 2, 3, 7);
  108. TensorMap<const Tensor<int, 3, RowMajor> > mat4(mat2.data(), 2, 3, 7);
  109. VERIFY_IS_EQUAL(mat3.rank(), 3);
  110. VERIFY_IS_EQUAL(mat3.size(), 2*3*7);
  111. VERIFY_IS_EQUAL(mat3.dimension(0), 2);
  112. VERIFY_IS_EQUAL(mat3.dimension(1), 3);
  113. VERIFY_IS_EQUAL(mat3.dimension(2), 7);
  114. VERIFY_IS_EQUAL(mat4.rank(), 3);
  115. VERIFY_IS_EQUAL(mat4.size(), 2*3*7);
  116. VERIFY_IS_EQUAL(mat4.dimension(0), 2);
  117. VERIFY_IS_EQUAL(mat4.dimension(1), 3);
  118. VERIFY_IS_EQUAL(mat4.dimension(2), 7);
  119. val = 0;
  120. for (int i = 0; i < 2; ++i) {
  121. for (int j = 0; j < 3; ++j) {
  122. for (int k = 0; k < 7; ++k) {
  123. VERIFY_IS_EQUAL(mat3(i,j,k), val);
  124. VERIFY_IS_EQUAL(mat4(i,j,k), val);
  125. val++;
  126. }
  127. }
  128. }
  129. }
  130. static void test_from_tensor()
  131. {
  132. Tensor<int, 3> mat1(2,3,7);
  133. Tensor<int, 3, RowMajor> mat2(2,3,7);
  134. int val = 0;
  135. for (int i = 0; i < 2; ++i) {
  136. for (int j = 0; j < 3; ++j) {
  137. for (int k = 0; k < 7; ++k) {
  138. mat1(i,j,k) = val;
  139. mat2(i,j,k) = val;
  140. val++;
  141. }
  142. }
  143. }
  144. TensorMap<Tensor<int, 3> > mat3(mat1);
  145. TensorMap<Tensor<int, 3, RowMajor> > mat4(mat2);
  146. VERIFY_IS_EQUAL(mat3.rank(), 3);
  147. VERIFY_IS_EQUAL(mat3.size(), 2*3*7);
  148. VERIFY_IS_EQUAL(mat3.dimension(0), 2);
  149. VERIFY_IS_EQUAL(mat3.dimension(1), 3);
  150. VERIFY_IS_EQUAL(mat3.dimension(2), 7);
  151. VERIFY_IS_EQUAL(mat4.rank(), 3);
  152. VERIFY_IS_EQUAL(mat4.size(), 2*3*7);
  153. VERIFY_IS_EQUAL(mat4.dimension(0), 2);
  154. VERIFY_IS_EQUAL(mat4.dimension(1), 3);
  155. VERIFY_IS_EQUAL(mat4.dimension(2), 7);
  156. val = 0;
  157. for (int i = 0; i < 2; ++i) {
  158. for (int j = 0; j < 3; ++j) {
  159. for (int k = 0; k < 7; ++k) {
  160. VERIFY_IS_EQUAL(mat3(i,j,k), val);
  161. VERIFY_IS_EQUAL(mat4(i,j,k), val);
  162. val++;
  163. }
  164. }
  165. }
  166. TensorFixedSize<int, Sizes<2,3,7> > mat5;
  167. val = 0;
  168. for (int i = 0; i < 2; ++i) {
  169. for (int j = 0; j < 3; ++j) {
  170. for (int k = 0; k < 7; ++k) {
  171. array<ptrdiff_t, 3> coords;
  172. coords[0] = i;
  173. coords[1] = j;
  174. coords[2] = k;
  175. mat5(coords) = val;
  176. val++;
  177. }
  178. }
  179. }
  180. TensorMap<TensorFixedSize<int, Sizes<2,3,7> > > mat6(mat5);
  181. VERIFY_IS_EQUAL(mat6.rank(), 3);
  182. VERIFY_IS_EQUAL(mat6.size(), 2*3*7);
  183. VERIFY_IS_EQUAL(mat6.dimension(0), 2);
  184. VERIFY_IS_EQUAL(mat6.dimension(1), 3);
  185. VERIFY_IS_EQUAL(mat6.dimension(2), 7);
  186. val = 0;
  187. for (int i = 0; i < 2; ++i) {
  188. for (int j = 0; j < 3; ++j) {
  189. for (int k = 0; k < 7; ++k) {
  190. VERIFY_IS_EQUAL(mat6(i,j,k), val);
  191. val++;
  192. }
  193. }
  194. }
  195. }
  196. static int f(const TensorMap<Tensor<int, 3> >& tensor) {
  197. // Size<0> empty;
  198. EIGEN_STATIC_ASSERT((internal::array_size<Sizes<> >::value == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
  199. EIGEN_STATIC_ASSERT((internal::array_size<DSizes<int, 0> >::value == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
  200. Tensor<int, 0> result = tensor.sum();
  201. return result();
  202. }
  203. static void test_casting()
  204. {
  205. Tensor<int, 3> tensor(2,3,7);
  206. int val = 0;
  207. for (int i = 0; i < 2; ++i) {
  208. for (int j = 0; j < 3; ++j) {
  209. for (int k = 0; k < 7; ++k) {
  210. tensor(i,j,k) = val;
  211. val++;
  212. }
  213. }
  214. }
  215. TensorMap<Tensor<int, 3> > map(tensor);
  216. int sum1 = f(map);
  217. int sum2 = f(tensor);
  218. VERIFY_IS_EQUAL(sum1, sum2);
  219. VERIFY_IS_EQUAL(sum1, 861);
  220. }
  // Returns the same object viewed through a const reference. Callers use this
  // to select const member overloads, e.g. the const data() accessor of a
  // Tensor, without making a copy.
  template <typename T>
  static const T& add_const(T& value)
  {
    const T& readOnlyView = value;
    return readOnlyView;
  }
  225. static void test_0d_const_tensor()
  226. {
  227. Tensor<int, 0> scalar1;
  228. Tensor<int, 0, RowMajor> scalar2;
  229. TensorMap<const Tensor<int, 0> > scalar3(add_const(scalar1).data());
  230. TensorMap<const Tensor<int, 0, RowMajor> > scalar4(add_const(scalar2).data());
  231. scalar1() = 7;
  232. scalar2() = 13;
  233. VERIFY_IS_EQUAL(scalar1.rank(), 0);
  234. VERIFY_IS_EQUAL(scalar1.size(), 1);
  235. VERIFY_IS_EQUAL(scalar3(), 7);
  236. VERIFY_IS_EQUAL(scalar4(), 13);
  237. }
  238. static void test_0d_const_tensor_map()
  239. {
  240. Tensor<int, 0> scalar1;
  241. Tensor<int, 0, RowMajor> scalar2;
  242. const TensorMap<Tensor<int, 0> > scalar3(scalar1.data());
  243. const TensorMap<Tensor<int, 0, RowMajor> > scalar4(scalar2.data());
  244. // Although TensorMap is constant, we still can write to the underlying
  245. // storage, because we map over non-constant Tensor.
  246. scalar3() = 7;
  247. scalar4() = 13;
  248. VERIFY_IS_EQUAL(scalar1(), 7);
  249. VERIFY_IS_EQUAL(scalar2(), 13);
  250. // Pointer to the underlying storage is also non-const.
  251. scalar3.data()[0] = 8;
  252. scalar4.data()[0] = 14;
  253. VERIFY_IS_EQUAL(scalar1(), 8);
  254. VERIFY_IS_EQUAL(scalar2(), 14);
  255. }
  256. EIGEN_DECLARE_TEST(cxx11_tensor_map)
  257. {
  258. CALL_SUBTEST(test_0d());
  259. CALL_SUBTEST(test_1d());
  260. CALL_SUBTEST(test_2d());
  261. CALL_SUBTEST(test_3d());
  262. CALL_SUBTEST(test_from_tensor());
  263. CALL_SUBTEST(test_casting());
  264. CALL_SUBTEST(test_0d_const_tensor());
  265. CALL_SUBTEST(test_0d_const_tensor_map());
  266. }