cxx11_tensor_convolution.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "main.h"
  10. #include <Eigen/CXX11/Tensor>
  11. using Eigen::Tensor;
  12. using Eigen::DefaultDevice;
  13. template <int DataLayout>
  14. static void test_evals()
  15. {
  16. Tensor<float, 2, DataLayout> input(3, 3);
  17. Tensor<float, 1, DataLayout> kernel(2);
  18. input.setRandom();
  19. kernel.setRandom();
  20. Tensor<float, 2, DataLayout> result(2,3);
  21. result.setZero();
  22. Eigen::array<Tensor<float, 2>::Index, 1> dims3;
  23. dims3[0] = 0;
  24. typedef TensorEvaluator<decltype(input.convolve(kernel, dims3)), DefaultDevice> Evaluator;
  25. Evaluator eval(input.convolve(kernel, dims3), DefaultDevice());
  26. eval.evalTo(result.data());
  27. EIGEN_STATIC_ASSERT(Evaluator::NumDims==2ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
  28. VERIFY_IS_EQUAL(eval.dimensions()[0], 2);
  29. VERIFY_IS_EQUAL(eval.dimensions()[1], 3);
  30. VERIFY_IS_APPROX(result(0,0), input(0,0)*kernel(0) + input(1,0)*kernel(1)); // index 0
  31. VERIFY_IS_APPROX(result(0,1), input(0,1)*kernel(0) + input(1,1)*kernel(1)); // index 2
  32. VERIFY_IS_APPROX(result(0,2), input(0,2)*kernel(0) + input(1,2)*kernel(1)); // index 4
  33. VERIFY_IS_APPROX(result(1,0), input(1,0)*kernel(0) + input(2,0)*kernel(1)); // index 1
  34. VERIFY_IS_APPROX(result(1,1), input(1,1)*kernel(0) + input(2,1)*kernel(1)); // index 3
  35. VERIFY_IS_APPROX(result(1,2), input(1,2)*kernel(0) + input(2,2)*kernel(1)); // index 5
  36. }
  37. template <int DataLayout>
  38. static void test_expr()
  39. {
  40. Tensor<float, 2, DataLayout> input(3, 3);
  41. Tensor<float, 2, DataLayout> kernel(2, 2);
  42. input.setRandom();
  43. kernel.setRandom();
  44. Tensor<float, 2, DataLayout> result(2,2);
  45. Eigen::array<ptrdiff_t, 2> dims;
  46. dims[0] = 0;
  47. dims[1] = 1;
  48. result = input.convolve(kernel, dims);
  49. VERIFY_IS_APPROX(result(0,0), input(0,0)*kernel(0,0) + input(0,1)*kernel(0,1) +
  50. input(1,0)*kernel(1,0) + input(1,1)*kernel(1,1));
  51. VERIFY_IS_APPROX(result(0,1), input(0,1)*kernel(0,0) + input(0,2)*kernel(0,1) +
  52. input(1,1)*kernel(1,0) + input(1,2)*kernel(1,1));
  53. VERIFY_IS_APPROX(result(1,0), input(1,0)*kernel(0,0) + input(1,1)*kernel(0,1) +
  54. input(2,0)*kernel(1,0) + input(2,1)*kernel(1,1));
  55. VERIFY_IS_APPROX(result(1,1), input(1,1)*kernel(0,0) + input(1,2)*kernel(0,1) +
  56. input(2,1)*kernel(1,0) + input(2,2)*kernel(1,1));
  57. }
  58. template <int DataLayout>
  59. static void test_modes() {
  60. Tensor<float, 1, DataLayout> input(3);
  61. Tensor<float, 1, DataLayout> kernel(3);
  62. input(0) = 1.0f;
  63. input(1) = 2.0f;
  64. input(2) = 3.0f;
  65. kernel(0) = 0.5f;
  66. kernel(1) = 1.0f;
  67. kernel(2) = 0.0f;
  68. Eigen::array<ptrdiff_t, 1> dims;
  69. dims[0] = 0;
  70. Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, 1> padding;
  71. // Emulate VALID mode (as defined in
  72. // http://docs.scipy.org/doc/numpy/reference/generated/numpy.convolve.html).
  73. padding[0] = std::make_pair(0, 0);
  74. Tensor<float, 1, DataLayout> valid(1);
  75. valid = input.pad(padding).convolve(kernel, dims);
  76. VERIFY_IS_EQUAL(valid.dimension(0), 1);
  77. VERIFY_IS_APPROX(valid(0), 2.5f);
  78. // Emulate SAME mode (as defined in
  79. // http://docs.scipy.org/doc/numpy/reference/generated/numpy.convolve.html).
  80. padding[0] = std::make_pair(1, 1);
  81. Tensor<float, 1, DataLayout> same(3);
  82. same = input.pad(padding).convolve(kernel, dims);
  83. VERIFY_IS_EQUAL(same.dimension(0), 3);
  84. VERIFY_IS_APPROX(same(0), 1.0f);
  85. VERIFY_IS_APPROX(same(1), 2.5f);
  86. VERIFY_IS_APPROX(same(2), 4.0f);
  87. // Emulate FULL mode (as defined in
  88. // http://docs.scipy.org/doc/numpy/reference/generated/numpy.convolve.html).
  89. padding[0] = std::make_pair(2, 2);
  90. Tensor<float, 1, DataLayout> full(5);
  91. full = input.pad(padding).convolve(kernel, dims);
  92. VERIFY_IS_EQUAL(full.dimension(0), 5);
  93. VERIFY_IS_APPROX(full(0), 0.0f);
  94. VERIFY_IS_APPROX(full(1), 1.0f);
  95. VERIFY_IS_APPROX(full(2), 2.5f);
  96. VERIFY_IS_APPROX(full(3), 4.0f);
  97. VERIFY_IS_APPROX(full(4), 1.5f);
  98. }
  99. template <int DataLayout>
  100. static void test_strides() {
  101. Tensor<float, 1, DataLayout> input(13);
  102. Tensor<float, 1, DataLayout> kernel(3);
  103. input.setRandom();
  104. kernel.setRandom();
  105. Eigen::array<ptrdiff_t, 1> dims;
  106. dims[0] = 0;
  107. Eigen::array<ptrdiff_t, 1> stride_of_3;
  108. stride_of_3[0] = 3;
  109. Eigen::array<ptrdiff_t, 1> stride_of_2;
  110. stride_of_2[0] = 2;
  111. Tensor<float, 1, DataLayout> result;
  112. result = input.stride(stride_of_3).convolve(kernel, dims).stride(stride_of_2);
  113. VERIFY_IS_EQUAL(result.dimension(0), 2);
  114. VERIFY_IS_APPROX(result(0), (input(0)*kernel(0) + input(3)*kernel(1) +
  115. input(6)*kernel(2)));
  116. VERIFY_IS_APPROX(result(1), (input(6)*kernel(0) + input(9)*kernel(1) +
  117. input(12)*kernel(2)));
  118. }
  119. EIGEN_DECLARE_TEST(cxx11_tensor_convolution)
  120. {
  121. CALL_SUBTEST(test_evals<ColMajor>());
  122. CALL_SUBTEST(test_evals<RowMajor>());
  123. CALL_SUBTEST(test_expr<ColMajor>());
  124. CALL_SUBTEST(test_expr<RowMajor>());
  125. CALL_SUBTEST(test_modes<ColMajor>());
  126. CALL_SUBTEST(test_modes<RowMajor>());
  127. CALL_SUBTEST(test_strides<ColMajor>());
  128. CALL_SUBTEST(test_strides<RowMajor>());
  129. }