cxx11_tensor_scan.cpp 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2016 Igor Babuschkin <igor@babuschk.in>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <limits>
  11. #include <numeric>
  12. #include <Eigen/CXX11/Tensor>
  13. using Eigen::Tensor;
  14. template <int DataLayout, typename Type=float, bool Exclusive = false>
  15. static void test_1d_scan()
  16. {
  17. int size = 50;
  18. Tensor<Type, 1, DataLayout> tensor(size);
  19. tensor.setRandom();
  20. Tensor<Type, 1, DataLayout> result = tensor.cumsum(0, Exclusive);
  21. VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0));
  22. float accum = 0;
  23. for (int i = 0; i < size; i++) {
  24. if (Exclusive) {
  25. VERIFY_IS_EQUAL(result(i), accum);
  26. accum += tensor(i);
  27. } else {
  28. accum += tensor(i);
  29. VERIFY_IS_EQUAL(result(i), accum);
  30. }
  31. }
  32. accum = 1;
  33. result = tensor.cumprod(0, Exclusive);
  34. for (int i = 0; i < size; i++) {
  35. if (Exclusive) {
  36. VERIFY_IS_EQUAL(result(i), accum);
  37. accum *= tensor(i);
  38. } else {
  39. accum *= tensor(i);
  40. VERIFY_IS_EQUAL(result(i), accum);
  41. }
  42. }
  43. }
  44. template <int DataLayout, typename Type=float>
  45. static void test_4d_scan()
  46. {
  47. int size = 5;
  48. Tensor<Type, 4, DataLayout> tensor(size, size, size, size);
  49. tensor.setRandom();
  50. Tensor<Type, 4, DataLayout> result(size, size, size, size);
  51. result = tensor.cumsum(0);
  52. float accum = 0;
  53. for (int i = 0; i < size; i++) {
  54. accum += tensor(i, 1, 2, 3);
  55. VERIFY_IS_EQUAL(result(i, 1, 2, 3), accum);
  56. }
  57. result = tensor.cumsum(1);
  58. accum = 0;
  59. for (int i = 0; i < size; i++) {
  60. accum += tensor(1, i, 2, 3);
  61. VERIFY_IS_EQUAL(result(1, i, 2, 3), accum);
  62. }
  63. result = tensor.cumsum(2);
  64. accum = 0;
  65. for (int i = 0; i < size; i++) {
  66. accum += tensor(1, 2, i, 3);
  67. VERIFY_IS_EQUAL(result(1, 2, i, 3), accum);
  68. }
  69. result = tensor.cumsum(3);
  70. accum = 0;
  71. for (int i = 0; i < size; i++) {
  72. accum += tensor(1, 2, 3, i);
  73. VERIFY_IS_EQUAL(result(1, 2, 3, i), accum);
  74. }
  75. }
  76. template <int DataLayout>
  77. static void test_tensor_maps() {
  78. int inputs[20];
  79. TensorMap<Tensor<int, 1, DataLayout> > tensor_map(inputs, 20);
  80. tensor_map.setRandom();
  81. Tensor<int, 1, DataLayout> result = tensor_map.cumsum(0);
  82. int accum = 0;
  83. for (int i = 0; i < 20; ++i) {
  84. accum += tensor_map(i);
  85. VERIFY_IS_EQUAL(result(i), accum);
  86. }
  87. }
  88. EIGEN_DECLARE_TEST(cxx11_tensor_scan) {
  89. CALL_SUBTEST((test_1d_scan<ColMajor, float, true>()));
  90. CALL_SUBTEST((test_1d_scan<ColMajor, float, false>()));
  91. CALL_SUBTEST((test_1d_scan<RowMajor, float, true>()));
  92. CALL_SUBTEST((test_1d_scan<RowMajor, float, false>()));
  93. CALL_SUBTEST(test_4d_scan<ColMajor>());
  94. CALL_SUBTEST(test_4d_scan<RowMajor>());
  95. CALL_SUBTEST(test_tensor_maps<ColMajor>());
  96. CALL_SUBTEST(test_tensor_maps<RowMajor>());
  97. }