cxx11_tensor_argmax.cpp 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2015 Eugene Brevdo <ebrevdo@google.com>
  5. // Benoit Steiner <benoit.steiner.goog@gmail.com>
  6. //
  7. // This Source Code Form is subject to the terms of the Mozilla
  8. // Public License v. 2.0. If a copy of the MPL was not distributed
  9. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  10. #include "main.h"
  11. #include <Eigen/CXX11/Tensor>
  12. using Eigen::Tensor;
  13. using Eigen::array;
  14. using Eigen::Tuple;
  15. template <int DataLayout>
  16. static void test_simple_index_tuples()
  17. {
  18. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  19. tensor.setRandom();
  20. tensor = (tensor + tensor.constant(0.5)).log();
  21. Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
  22. index_tuples = tensor.index_tuples();
  23. for (DenseIndex n = 0; n < 2*3*5*7; ++n) {
  24. const Tuple<DenseIndex, float>& v = index_tuples.coeff(n);
  25. VERIFY_IS_EQUAL(v.first, n);
  26. VERIFY_IS_EQUAL(v.second, tensor.coeff(n));
  27. }
  28. }
  29. template <int DataLayout>
  30. static void test_index_tuples_dim()
  31. {
  32. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  33. tensor.setRandom();
  34. tensor = (tensor + tensor.constant(0.5)).log();
  35. Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
  36. index_tuples = tensor.index_tuples();
  37. for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) {
  38. const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l);
  39. VERIFY_IS_EQUAL(v.first, n);
  40. VERIFY_IS_EQUAL(v.second, tensor(n));
  41. }
  42. }
  43. template <int DataLayout>
  44. static void test_argmax_tuple_reducer()
  45. {
  46. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  47. tensor.setRandom();
  48. tensor = (tensor + tensor.constant(0.5)).log();
  49. Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
  50. index_tuples = tensor.index_tuples();
  51. Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
  52. DimensionList<DenseIndex, 4> dims;
  53. reduced = index_tuples.reduce(
  54. dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
  55. Tensor<float, 0, DataLayout> maxi = tensor.maximum();
  56. VERIFY_IS_EQUAL(maxi(), reduced(0).second);
  57. array<DenseIndex, 3> reduce_dims;
  58. for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
  59. Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
  60. reduced_by_dims = index_tuples.reduce(
  61. reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
  62. Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims);
  63. for (int l = 0; l < 7; ++l) {
  64. VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second);
  65. }
  66. }
  67. template <int DataLayout>
  68. static void test_argmin_tuple_reducer()
  69. {
  70. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  71. tensor.setRandom();
  72. tensor = (tensor + tensor.constant(0.5)).log();
  73. Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
  74. index_tuples = tensor.index_tuples();
  75. Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
  76. DimensionList<DenseIndex, 4> dims;
  77. reduced = index_tuples.reduce(
  78. dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
  79. Tensor<float, 0, DataLayout> mini = tensor.minimum();
  80. VERIFY_IS_EQUAL(mini(), reduced(0).second);
  81. array<DenseIndex, 3> reduce_dims;
  82. for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
  83. Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
  84. reduced_by_dims = index_tuples.reduce(
  85. reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
  86. Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims);
  87. for (int l = 0; l < 7; ++l) {
  88. VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second);
  89. }
  90. }
  91. template <int DataLayout>
  92. static void test_simple_argmax()
  93. {
  94. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  95. tensor.setRandom();
  96. tensor = (tensor + tensor.constant(0.5)).log();
  97. tensor(0,0,0,0) = 10.0;
  98. Tensor<DenseIndex, 0, DataLayout> tensor_argmax;
  99. tensor_argmax = tensor.argmax();
  100. VERIFY_IS_EQUAL(tensor_argmax(0), 0);
  101. tensor(1,2,4,6) = 20.0;
  102. tensor_argmax = tensor.argmax();
  103. VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1);
  104. }
  105. template <int DataLayout>
  106. static void test_simple_argmin()
  107. {
  108. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  109. tensor.setRandom();
  110. tensor = (tensor + tensor.constant(0.5)).log();
  111. tensor(0,0,0,0) = -10.0;
  112. Tensor<DenseIndex, 0, DataLayout> tensor_argmin;
  113. tensor_argmin = tensor.argmin();
  114. VERIFY_IS_EQUAL(tensor_argmin(0), 0);
  115. tensor(1,2,4,6) = -20.0;
  116. tensor_argmin = tensor.argmin();
  117. VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1);
  118. }
  119. template <int DataLayout>
  120. static void test_argmax_dim()
  121. {
  122. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  123. std::vector<int> dims {2, 3, 5, 7};
  124. for (int dim = 0; dim < 4; ++dim) {
  125. tensor.setRandom();
  126. tensor = (tensor + tensor.constant(0.5)).log();
  127. Tensor<DenseIndex, 3, DataLayout> tensor_argmax;
  128. array<DenseIndex, 4> ix;
  129. for (int i = 0; i < 2; ++i) {
  130. for (int j = 0; j < 3; ++j) {
  131. for (int k = 0; k < 5; ++k) {
  132. for (int l = 0; l < 7; ++l) {
  133. ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
  134. if (ix[dim] != 0) continue;
  135. // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = 10.0
  136. tensor(ix) = 10.0;
  137. }
  138. }
  139. }
  140. }
  141. tensor_argmax = tensor.argmax(dim);
  142. VERIFY_IS_EQUAL(tensor_argmax.size(),
  143. ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
  144. for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
  145. // Expect max to be in the first index of the reduced dimension
  146. VERIFY_IS_EQUAL(tensor_argmax.data()[n], 0);
  147. }
  148. for (int i = 0; i < 2; ++i) {
  149. for (int j = 0; j < 3; ++j) {
  150. for (int k = 0; k < 5; ++k) {
  151. for (int l = 0; l < 7; ++l) {
  152. ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
  153. if (ix[dim] != tensor.dimension(dim) - 1) continue;
  154. // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = 20.0
  155. tensor(ix) = 20.0;
  156. }
  157. }
  158. }
  159. }
  160. tensor_argmax = tensor.argmax(dim);
  161. VERIFY_IS_EQUAL(tensor_argmax.size(),
  162. ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
  163. for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
  164. // Expect max to be in the last index of the reduced dimension
  165. VERIFY_IS_EQUAL(tensor_argmax.data()[n], tensor.dimension(dim) - 1);
  166. }
  167. }
  168. }
  169. template <int DataLayout>
  170. static void test_argmin_dim()
  171. {
  172. Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  173. std::vector<int> dims {2, 3, 5, 7};
  174. for (int dim = 0; dim < 4; ++dim) {
  175. tensor.setRandom();
  176. tensor = (tensor + tensor.constant(0.5)).log();
  177. Tensor<DenseIndex, 3, DataLayout> tensor_argmin;
  178. array<DenseIndex, 4> ix;
  179. for (int i = 0; i < 2; ++i) {
  180. for (int j = 0; j < 3; ++j) {
  181. for (int k = 0; k < 5; ++k) {
  182. for (int l = 0; l < 7; ++l) {
  183. ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
  184. if (ix[dim] != 0) continue;
  185. // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = -10.0
  186. tensor(ix) = -10.0;
  187. }
  188. }
  189. }
  190. }
  191. tensor_argmin = tensor.argmin(dim);
  192. VERIFY_IS_EQUAL(tensor_argmin.size(),
  193. ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
  194. for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
  195. // Expect min to be in the first index of the reduced dimension
  196. VERIFY_IS_EQUAL(tensor_argmin.data()[n], 0);
  197. }
  198. for (int i = 0; i < 2; ++i) {
  199. for (int j = 0; j < 3; ++j) {
  200. for (int k = 0; k < 5; ++k) {
  201. for (int l = 0; l < 7; ++l) {
  202. ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
  203. if (ix[dim] != tensor.dimension(dim) - 1) continue;
  204. // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = -20.0
  205. tensor(ix) = -20.0;
  206. }
  207. }
  208. }
  209. }
  210. tensor_argmin = tensor.argmin(dim);
  211. VERIFY_IS_EQUAL(tensor_argmin.size(),
  212. ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
  213. for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
  214. // Expect min to be in the last index of the reduced dimension
  215. VERIFY_IS_EQUAL(tensor_argmin.data()[n], tensor.dimension(dim) - 1);
  216. }
  217. }
  218. }
  219. EIGEN_DECLARE_TEST(cxx11_tensor_argmax)
  220. {
  221. CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
  222. CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
  223. CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
  224. CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
  225. CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
  226. CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
  227. CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
  228. CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
  229. CALL_SUBTEST(test_simple_argmax<RowMajor>());
  230. CALL_SUBTEST(test_simple_argmax<ColMajor>());
  231. CALL_SUBTEST(test_simple_argmin<RowMajor>());
  232. CALL_SUBTEST(test_simple_argmin<ColMajor>());
  233. CALL_SUBTEST(test_argmax_dim<RowMajor>());
  234. CALL_SUBTEST(test_argmax_dim<ColMajor>());
  235. CALL_SUBTEST(test_argmin_dim<RowMajor>());
  236. CALL_SUBTEST(test_argmin_dim<ColMajor>());
  237. }