cxx11_tensor_trace.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::array;
  13. template <int DataLayout>
  14. static void test_0D_trace() {
  15. Tensor<float, 0, DataLayout> tensor;
  16. tensor.setRandom();
  17. array<ptrdiff_t, 0> dims;
  18. Tensor<float, 0, DataLayout> result = tensor.trace(dims);
  19. VERIFY_IS_EQUAL(result(), tensor());
  20. }
  21. template <int DataLayout>
  22. static void test_all_dimensions_trace() {
  23. Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
  24. tensor1.setRandom();
  25. Tensor<float, 0, DataLayout> result1 = tensor1.trace();
  26. VERIFY_IS_EQUAL(result1.rank(), 0);
  27. float sum = 0.0f;
  28. for (int i = 0; i < 5; ++i) {
  29. sum += tensor1(i, i, i);
  30. }
  31. VERIFY_IS_EQUAL(result1(), sum);
  32. Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
  33. tensor2.setRandom();
  34. array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } };
  35. Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
  36. VERIFY_IS_EQUAL(result2.rank(), 0);
  37. sum = 0.0f;
  38. for (int i = 0; i < 7; ++i) {
  39. sum += tensor2(i, i, i, i, i);
  40. }
  41. VERIFY_IS_EQUAL(result2(), sum);
  42. }
  43. template <int DataLayout>
  44. static void test_simple_trace() {
  45. Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
  46. tensor1.setRandom();
  47. array<ptrdiff_t, 2> dims1 = { { 0, 2 } };
  48. Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
  49. VERIFY_IS_EQUAL(result1.rank(), 1);
  50. VERIFY_IS_EQUAL(result1.dimension(0), 5);
  51. float sum = 0.0f;
  52. for (int i = 0; i < 5; ++i) {
  53. sum = 0.0f;
  54. for (int j = 0; j < 3; ++j) {
  55. sum += tensor1(j, i, j);
  56. }
  57. VERIFY_IS_EQUAL(result1(i), sum);
  58. }
  59. Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
  60. tensor2.setRandom();
  61. array<ptrdiff_t, 2> dims2 = { { 2, 3 } };
  62. Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
  63. VERIFY_IS_EQUAL(result2.rank(), 2);
  64. VERIFY_IS_EQUAL(result2.dimension(0), 5);
  65. VERIFY_IS_EQUAL(result2.dimension(1), 5);
  66. for (int i = 0; i < 5; ++i) {
  67. for (int j = 0; j < 5; ++j) {
  68. sum = 0.0f;
  69. for (int k = 0; k < 7; ++k) {
  70. sum += tensor2(i, j, k, k);
  71. }
  72. VERIFY_IS_EQUAL(result2(i, j), sum);
  73. }
  74. }
  75. array<ptrdiff_t, 2> dims3 = { { 1, 0 } };
  76. Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
  77. VERIFY_IS_EQUAL(result3.rank(), 2);
  78. VERIFY_IS_EQUAL(result3.dimension(0), 7);
  79. VERIFY_IS_EQUAL(result3.dimension(1), 7);
  80. for (int i = 0; i < 7; ++i) {
  81. for (int j = 0; j < 7; ++j) {
  82. sum = 0.0f;
  83. for (int k = 0; k < 5; ++k) {
  84. sum += tensor2(k, k, i, j);
  85. }
  86. VERIFY_IS_EQUAL(result3(i, j), sum);
  87. }
  88. }
  89. Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
  90. tensor3.setRandom();
  91. array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } };
  92. Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
  93. VERIFY_IS_EQUAL(result4.rank(), 2);
  94. VERIFY_IS_EQUAL(result4.dimension(0), 7);
  95. VERIFY_IS_EQUAL(result4.dimension(1), 7);
  96. for (int i = 0; i < 7; ++i) {
  97. for (int j = 0; j < 7; ++j) {
  98. sum = 0.0f;
  99. for (int k = 0; k < 3; ++k) {
  100. sum += tensor3(k, i, k, j, k);
  101. }
  102. VERIFY_IS_EQUAL(result4(i, j), sum);
  103. }
  104. }
  105. Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
  106. tensor4.setRandom();
  107. array<ptrdiff_t, 2> dims5 = { { 1, 3 } };
  108. Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
  109. VERIFY_IS_EQUAL(result5.rank(), 3);
  110. VERIFY_IS_EQUAL(result5.dimension(0), 3);
  111. VERIFY_IS_EQUAL(result5.dimension(1), 4);
  112. VERIFY_IS_EQUAL(result5.dimension(2), 5);
  113. for (int i = 0; i < 3; ++i) {
  114. for (int j = 0; j < 4; ++j) {
  115. for (int k = 0; k < 5; ++k) {
  116. sum = 0.0f;
  117. for (int l = 0; l < 7; ++l) {
  118. sum += tensor4(i, l, j, l, k);
  119. }
  120. VERIFY_IS_EQUAL(result5(i, j, k), sum);
  121. }
  122. }
  123. }
  124. }
  125. template<int DataLayout>
  126. static void test_trace_in_expr() {
  127. Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
  128. tensor.setRandom();
  129. array<ptrdiff_t, 2> dims = { { 1, 3 } };
  130. Tensor<float, 2, DataLayout> result(2, 5);
  131. result = result.constant(1.0f) - tensor.trace(dims);
  132. VERIFY_IS_EQUAL(result.rank(), 2);
  133. VERIFY_IS_EQUAL(result.dimension(0), 2);
  134. VERIFY_IS_EQUAL(result.dimension(1), 5);
  135. float sum = 0.0f;
  136. for (int i = 0; i < 2; ++i) {
  137. for (int j = 0; j < 5; ++j) {
  138. sum = 0.0f;
  139. for (int k = 0; k < 3; ++k) {
  140. sum += tensor(i, k, j, k);
  141. }
  142. VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
  143. }
  144. }
  145. }
  146. EIGEN_DECLARE_TEST(cxx11_tensor_trace) {
  147. CALL_SUBTEST(test_0D_trace<ColMajor>());
  148. CALL_SUBTEST(test_0D_trace<RowMajor>());
  149. CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
  150. CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());
  151. CALL_SUBTEST(test_simple_trace<ColMajor>());
  152. CALL_SUBTEST(test_simple_trace<RowMajor>());
  153. CALL_SUBTEST(test_trace_in_expr<ColMajor>());
  154. CALL_SUBTEST(test_trace_in_expr<RowMajor>());
  155. }