cxx11_tensor_concatenation.cpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. template<int DataLayout>
  13. static void test_dimension_failures()
  14. {
  15. Tensor<int, 3, DataLayout> left(2, 3, 1);
  16. Tensor<int, 3, DataLayout> right(3, 3, 1);
  17. left.setRandom();
  18. right.setRandom();
  19. // Okay; other dimensions are equal.
  20. Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
  21. // Dimension mismatches.
  22. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
  23. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
  24. // Axis > NumDims or < 0.
  25. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
  26. VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
  27. }
  28. template<int DataLayout>
  29. static void test_static_dimension_failure()
  30. {
  31. Tensor<int, 2, DataLayout> left(2, 3);
  32. Tensor<int, 3, DataLayout> right(2, 3, 1);
  33. #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
  34. // Technically compatible, but we static assert that the inputs have same
  35. // NumDims.
  36. Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
  37. #endif
  38. // This can be worked around in this case.
  39. Tensor<int, 3, DataLayout> concatenation = left
  40. .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
  41. .concatenate(right, 0);
  42. Tensor<int, 2, DataLayout> alternative = left
  43. // Clang compiler break with {{{}}} with an ambiguous error on copy constructor
  44. // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
  45. // Solution:
  46. // either the code should change to
  47. // Tensor<int, 2>::Dimensions{{2, 3}}
  48. // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
  49. .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
  50. }
  51. template<int DataLayout>
  52. static void test_simple_concatenation()
  53. {
  54. Tensor<int, 3, DataLayout> left(2, 3, 1);
  55. Tensor<int, 3, DataLayout> right(2, 3, 1);
  56. left.setRandom();
  57. right.setRandom();
  58. Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
  59. VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
  60. VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
  61. VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
  62. for (int j = 0; j < 3; ++j) {
  63. for (int i = 0; i < 2; ++i) {
  64. VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
  65. }
  66. for (int i = 2; i < 4; ++i) {
  67. VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
  68. }
  69. }
  70. concatenation = left.concatenate(right, 1);
  71. VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
  72. VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
  73. VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
  74. for (int i = 0; i < 2; ++i) {
  75. for (int j = 0; j < 3; ++j) {
  76. VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
  77. }
  78. for (int j = 3; j < 6; ++j) {
  79. VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
  80. }
  81. }
  82. concatenation = left.concatenate(right, 2);
  83. VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
  84. VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
  85. VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
  86. for (int i = 0; i < 2; ++i) {
  87. for (int j = 0; j < 3; ++j) {
  88. VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
  89. VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
  90. }
  91. }
  92. }
  93. // TODO(phli): Add test once we have a real vectorized implementation.
  94. // static void test_vectorized_concatenation() {}
  95. static void test_concatenation_as_lvalue()
  96. {
  97. Tensor<int, 2> t1(2, 3);
  98. Tensor<int, 2> t2(2, 3);
  99. t1.setRandom();
  100. t2.setRandom();
  101. Tensor<int, 2> result(4, 3);
  102. result.setRandom();
  103. t1.concatenate(t2, 0) = result;
  104. for (int i = 0; i < 2; ++i) {
  105. for (int j = 0; j < 3; ++j) {
  106. VERIFY_IS_EQUAL(t1(i, j), result(i, j));
  107. VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
  108. }
  109. }
  110. }
  111. EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
  112. {
  113. CALL_SUBTEST(test_dimension_failures<ColMajor>());
  114. CALL_SUBTEST(test_dimension_failures<RowMajor>());
  115. CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
  116. CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
  117. CALL_SUBTEST(test_simple_concatenation<ColMajor>());
  118. CALL_SUBTEST(test_simple_concatenation<RowMajor>());
  119. // CALL_SUBTEST(test_vectorized_concatenation());
  120. CALL_SUBTEST(test_concatenation_as_lvalue());
  121. }