// cxx11_tensor_striding.cpp
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. template<int DataLayout>
  13. static void test_simple_striding()
  14. {
  15. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  16. tensor.setRandom();
  17. array<ptrdiff_t, 4> strides;
  18. strides[0] = 1;
  19. strides[1] = 1;
  20. strides[2] = 1;
  21. strides[3] = 1;
  22. Tensor<float, 4, DataLayout> no_stride;
  23. no_stride = tensor.stride(strides);
  24. VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
  25. VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
  26. VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
  27. VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
  28. for (int i = 0; i < 2; ++i) {
  29. for (int j = 0; j < 3; ++j) {
  30. for (int k = 0; k < 5; ++k) {
  31. for (int l = 0; l < 7; ++l) {
  32. VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
  33. }
  34. }
  35. }
  36. }
  37. strides[0] = 2;
  38. strides[1] = 4;
  39. strides[2] = 2;
  40. strides[3] = 3;
  41. Tensor<float, 4, DataLayout> stride;
  42. stride = tensor.stride(strides);
  43. VERIFY_IS_EQUAL(stride.dimension(0), 1);
  44. VERIFY_IS_EQUAL(stride.dimension(1), 1);
  45. VERIFY_IS_EQUAL(stride.dimension(2), 3);
  46. VERIFY_IS_EQUAL(stride.dimension(3), 3);
  47. for (int i = 0; i < 1; ++i) {
  48. for (int j = 0; j < 1; ++j) {
  49. for (int k = 0; k < 3; ++k) {
  50. for (int l = 0; l < 3; ++l) {
  51. VERIFY_IS_EQUAL(tensor(2*i,4*j,2*k,3*l), stride(i,j,k,l));
  52. }
  53. }
  54. }
  55. }
  56. }
  57. template<int DataLayout>
  58. static void test_striding_as_lvalue()
  59. {
  60. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  61. tensor.setRandom();
  62. array<ptrdiff_t, 4> strides;
  63. strides[0] = 2;
  64. strides[1] = 4;
  65. strides[2] = 2;
  66. strides[3] = 3;
  67. Tensor<float, 4, DataLayout> result(3, 12, 10, 21);
  68. result.stride(strides) = tensor;
  69. for (int i = 0; i < 2; ++i) {
  70. for (int j = 0; j < 3; ++j) {
  71. for (int k = 0; k < 5; ++k) {
  72. for (int l = 0; l < 7; ++l) {
  73. VERIFY_IS_EQUAL(tensor(i,j,k,l), result(2*i,4*j,2*k,3*l));
  74. }
  75. }
  76. }
  77. }
  78. array<ptrdiff_t, 4> no_strides;
  79. no_strides[0] = 1;
  80. no_strides[1] = 1;
  81. no_strides[2] = 1;
  82. no_strides[3] = 1;
  83. Tensor<float, 4, DataLayout> result2(3, 12, 10, 21);
  84. result2.stride(strides) = tensor.stride(no_strides);
  85. for (int i = 0; i < 2; ++i) {
  86. for (int j = 0; j < 3; ++j) {
  87. for (int k = 0; k < 5; ++k) {
  88. for (int l = 0; l < 7; ++l) {
  89. VERIFY_IS_EQUAL(tensor(i,j,k,l), result2(2*i,4*j,2*k,3*l));
  90. }
  91. }
  92. }
  93. }
  94. }
  95. EIGEN_DECLARE_TEST(cxx11_tensor_striding)
  96. {
  97. CALL_SUBTEST(test_simple_striding<ColMajor>());
  98. CALL_SUBTEST(test_simple_striding<RowMajor>());
  99. CALL_SUBTEST(test_striding_as_lvalue<ColMajor>());
  100. CALL_SUBTEST(test_striding_as_lvalue<RowMajor>());
  101. }