cxx11_tensor_inflation.cpp 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  // This file is part of Eigen, a lightweight C++ template library
  // for linear algebra.
  //
  // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
  //
  // This Source Code Form is subject to the terms of the Mozilla
  // Public License v. 2.0. If a copy of the MPL was not distributed
  // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. template<int DataLayout>
  13. static void test_simple_inflation()
  14. {
  15. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  16. tensor.setRandom();
  17. array<ptrdiff_t, 4> strides;
  18. strides[0] = 1;
  19. strides[1] = 1;
  20. strides[2] = 1;
  21. strides[3] = 1;
  22. Tensor<float, 4, DataLayout> no_stride;
  23. no_stride = tensor.inflate(strides);
  24. VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
  25. VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
  26. VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
  27. VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
  28. for (int i = 0; i < 2; ++i) {
  29. for (int j = 0; j < 3; ++j) {
  30. for (int k = 0; k < 5; ++k) {
  31. for (int l = 0; l < 7; ++l) {
  32. VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
  33. }
  34. }
  35. }
  36. }
  37. strides[0] = 2;
  38. strides[1] = 4;
  39. strides[2] = 2;
  40. strides[3] = 3;
  41. Tensor<float, 4, DataLayout> inflated;
  42. inflated = tensor.inflate(strides);
  43. VERIFY_IS_EQUAL(inflated.dimension(0), 3);
  44. VERIFY_IS_EQUAL(inflated.dimension(1), 9);
  45. VERIFY_IS_EQUAL(inflated.dimension(2), 9);
  46. VERIFY_IS_EQUAL(inflated.dimension(3), 19);
  47. for (int i = 0; i < 3; ++i) {
  48. for (int j = 0; j < 9; ++j) {
  49. for (int k = 0; k < 9; ++k) {
  50. for (int l = 0; l < 19; ++l) {
  51. if (i % 2 == 0 &&
  52. j % 4 == 0 &&
  53. k % 2 == 0 &&
  54. l % 3 == 0) {
  55. VERIFY_IS_EQUAL(inflated(i,j,k,l),
  56. tensor(i/2, j/4, k/2, l/3));
  57. } else {
  58. VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
  59. }
  60. }
  61. }
  62. }
  63. }
  64. }
  65. EIGEN_DECLARE_TEST(cxx11_tensor_inflation)
  66. {
  67. CALL_SUBTEST(test_simple_inflation<ColMajor>());
  68. CALL_SUBTEST(test_simple_inflation<RowMajor>());
  69. }