cxx11_tensor_padding.cpp 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. template<int DataLayout>
  13. static void test_simple_padding()
  14. {
  15. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  16. tensor.setRandom();
  17. array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
  18. paddings[0] = std::make_pair(0, 0);
  19. paddings[1] = std::make_pair(2, 1);
  20. paddings[2] = std::make_pair(3, 4);
  21. paddings[3] = std::make_pair(0, 0);
  22. Tensor<float, 4, DataLayout> padded;
  23. padded = tensor.pad(paddings);
  24. VERIFY_IS_EQUAL(padded.dimension(0), 2+0);
  25. VERIFY_IS_EQUAL(padded.dimension(1), 3+3);
  26. VERIFY_IS_EQUAL(padded.dimension(2), 5+7);
  27. VERIFY_IS_EQUAL(padded.dimension(3), 7+0);
  28. for (int i = 0; i < 2; ++i) {
  29. for (int j = 0; j < 6; ++j) {
  30. for (int k = 0; k < 12; ++k) {
  31. for (int l = 0; l < 7; ++l) {
  32. if (j >= 2 && j < 5 && k >= 3 && k < 8) {
  33. VERIFY_IS_EQUAL(padded(i,j,k,l), tensor(i,j-2,k-3,l));
  34. } else {
  35. VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f);
  36. }
  37. }
  38. }
  39. }
  40. }
  41. }
  42. template<int DataLayout>
  43. static void test_padded_expr()
  44. {
  45. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  46. tensor.setRandom();
  47. array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
  48. paddings[0] = std::make_pair(0, 0);
  49. paddings[1] = std::make_pair(2, 1);
  50. paddings[2] = std::make_pair(3, 4);
  51. paddings[3] = std::make_pair(0, 0);
  52. Eigen::DSizes<ptrdiff_t, 2> reshape_dims;
  53. reshape_dims[0] = 12;
  54. reshape_dims[1] = 84;
  55. Tensor<float, 2, DataLayout> result;
  56. result = tensor.pad(paddings).reshape(reshape_dims);
  57. for (int i = 0; i < 2; ++i) {
  58. for (int j = 0; j < 6; ++j) {
  59. for (int k = 0; k < 12; ++k) {
  60. for (int l = 0; l < 7; ++l) {
  61. const float result_value = DataLayout == ColMajor ?
  62. result(i+2*j,k+12*l) : result(j+6*i,l+7*k);
  63. if (j >= 2 && j < 5 && k >= 3 && k < 8) {
  64. VERIFY_IS_EQUAL(result_value, tensor(i,j-2,k-3,l));
  65. } else {
  66. VERIFY_IS_EQUAL(result_value, 0.0f);
  67. }
  68. }
  69. }
  70. }
  71. }
  72. }
// Test driver: each padding check is instantiated once per storage order
// (ColMajor and RowMajor) so both layouts of pad() are exercised.
// NOTE(review): CALL_SUBTEST presumably records its argument text as the
// subtest name, and the subtests share the random generator used by
// setRandom(). Keep the argument text and the call order stable.
EIGEN_DECLARE_TEST(cxx11_tensor_padding)
{
  // pad() evaluated directly into a Tensor.
  CALL_SUBTEST(test_simple_padding<ColMajor>());
  CALL_SUBTEST(test_simple_padding<RowMajor>());
  // pad() followed by reshape() in one expression.
  CALL_SUBTEST(test_padded_expr<ColMajor>());
  CALL_SUBTEST(test_padded_expr<RowMajor>());
}