cxx11_tensor_reverse.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  // This file is part of Eigen, a lightweight C++ template library
  // for linear algebra.
  //
  // Copyright (C) 2014 Navdeep Jaitly <ndjaitly@google.com> and
  //                    Benoit Steiner <benoit.steiner.goog@gmail.com>
  //
  // This Source Code Form is subject to the terms of the Mozilla
  // Public License v. 2.0. If a copy of the MPL was not distributed
  // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  10. #include "main.h"
  11. #include <Eigen/CXX11/Tensor>
  12. using Eigen::Tensor;
  13. using Eigen::array;
  14. template <int DataLayout>
  15. static void test_simple_reverse()
  16. {
  17. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  18. tensor.setRandom();
  19. array<bool, 4> dim_rev;
  20. dim_rev[0] = false;
  21. dim_rev[1] = true;
  22. dim_rev[2] = true;
  23. dim_rev[3] = false;
  24. Tensor<float, 4, DataLayout> reversed_tensor;
  25. reversed_tensor = tensor.reverse(dim_rev);
  26. VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
  27. VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
  28. VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
  29. VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
  30. for (int i = 0; i < 2; ++i) {
  31. for (int j = 0; j < 3; ++j) {
  32. for (int k = 0; k < 5; ++k) {
  33. for (int l = 0; l < 7; ++l) {
  34. VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(i,2-j,4-k,l));
  35. }
  36. }
  37. }
  38. }
  39. dim_rev[0] = true;
  40. dim_rev[1] = false;
  41. dim_rev[2] = false;
  42. dim_rev[3] = false;
  43. reversed_tensor = tensor.reverse(dim_rev);
  44. VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
  45. VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
  46. VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
  47. VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
  48. for (int i = 0; i < 2; ++i) {
  49. for (int j = 0; j < 3; ++j) {
  50. for (int k = 0; k < 5; ++k) {
  51. for (int l = 0; l < 7; ++l) {
  52. VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,l));
  53. }
  54. }
  55. }
  56. }
  57. dim_rev[0] = true;
  58. dim_rev[1] = false;
  59. dim_rev[2] = false;
  60. dim_rev[3] = true;
  61. reversed_tensor = tensor.reverse(dim_rev);
  62. VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2);
  63. VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3);
  64. VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5);
  65. VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7);
  66. for (int i = 0; i < 2; ++i) {
  67. for (int j = 0; j < 3; ++j) {
  68. for (int k = 0; k < 5; ++k) {
  69. for (int l = 0; l < 7; ++l) {
  70. VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,6-l));
  71. }
  72. }
  73. }
  74. }
  75. }
  76. template <int DataLayout>
  77. static void test_expr_reverse(bool LValue)
  78. {
  79. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  80. tensor.setRandom();
  81. array<bool, 4> dim_rev;
  82. dim_rev[0] = false;
  83. dim_rev[1] = true;
  84. dim_rev[2] = false;
  85. dim_rev[3] = true;
  86. Tensor<float, 4, DataLayout> expected(2, 3, 5, 7);
  87. if (LValue) {
  88. expected.reverse(dim_rev) = tensor;
  89. } else {
  90. expected = tensor.reverse(dim_rev);
  91. }
  92. Tensor<float, 4, DataLayout> result(2,3,5,7);
  93. array<ptrdiff_t, 4> src_slice_dim;
  94. src_slice_dim[0] = 2;
  95. src_slice_dim[1] = 3;
  96. src_slice_dim[2] = 1;
  97. src_slice_dim[3] = 7;
  98. array<ptrdiff_t, 4> src_slice_start;
  99. src_slice_start[0] = 0;
  100. src_slice_start[1] = 0;
  101. src_slice_start[2] = 0;
  102. src_slice_start[3] = 0;
  103. array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim;
  104. array<ptrdiff_t, 4> dst_slice_start = src_slice_start;
  105. for (int i = 0; i < 5; ++i) {
  106. if (LValue) {
  107. result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
  108. tensor.slice(src_slice_start, src_slice_dim);
  109. } else {
  110. result.slice(dst_slice_start, dst_slice_dim) =
  111. tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev);
  112. }
  113. src_slice_start[2] += 1;
  114. dst_slice_start[2] += 1;
  115. }
  116. VERIFY_IS_EQUAL(result.dimension(0), 2);
  117. VERIFY_IS_EQUAL(result.dimension(1), 3);
  118. VERIFY_IS_EQUAL(result.dimension(2), 5);
  119. VERIFY_IS_EQUAL(result.dimension(3), 7);
  120. for (int i = 0; i < expected.dimension(0); ++i) {
  121. for (int j = 0; j < expected.dimension(1); ++j) {
  122. for (int k = 0; k < expected.dimension(2); ++k) {
  123. for (int l = 0; l < expected.dimension(3); ++l) {
  124. VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
  125. }
  126. }
  127. }
  128. }
  129. dst_slice_start[2] = 0;
  130. result.setRandom();
  131. for (int i = 0; i < 5; ++i) {
  132. if (LValue) {
  133. result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) =
  134. tensor.slice(dst_slice_start, dst_slice_dim);
  135. } else {
  136. result.slice(dst_slice_start, dst_slice_dim) =
  137. tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim);
  138. }
  139. dst_slice_start[2] += 1;
  140. }
  141. for (int i = 0; i < expected.dimension(0); ++i) {
  142. for (int j = 0; j < expected.dimension(1); ++j) {
  143. for (int k = 0; k < expected.dimension(2); ++k) {
  144. for (int l = 0; l < expected.dimension(3); ++l) {
  145. VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
  146. }
  147. }
  148. }
  149. }
  150. }
  151. EIGEN_DECLARE_TEST(cxx11_tensor_reverse)
  152. {
  153. CALL_SUBTEST(test_simple_reverse<ColMajor>());
  154. CALL_SUBTEST(test_simple_reverse<RowMajor>());
  155. CALL_SUBTEST(test_expr_reverse<ColMajor>(true));
  156. CALL_SUBTEST(test_expr_reverse<RowMajor>(true));
  157. CALL_SUBTEST(test_expr_reverse<ColMajor>(false));
  158. CALL_SUBTEST(test_expr_reverse<RowMajor>(false));
  159. }