cxx11_tensor_sugar.cpp 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #include "main.h"
  2. #include <Eigen/CXX11/Tensor>
  3. using Eigen::Tensor;
  4. using Eigen::RowMajor;
  5. static void test_comparison_sugar() {
  6. // we already trust comparisons between tensors, we're simply checking that
  7. // the sugared versions are doing the same thing
  8. Tensor<int, 3> t(6, 7, 5);
  9. t.setRandom();
  10. // make sure we have at least one value == 0
  11. t(0,0,0) = 0;
  12. Tensor<bool,0> b;
  13. #define TEST_TENSOR_EQUAL(e1, e2) \
  14. b = ((e1) == (e2)).all(); \
  15. VERIFY(b())
  16. #define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))
  17. TEST_OP(==);
  18. TEST_OP(!=);
  19. TEST_OP(<=);
  20. TEST_OP(>=);
  21. TEST_OP(<);
  22. TEST_OP(>);
  23. #undef TEST_OP
  24. #undef TEST_TENSOR_EQUAL
  25. }
  26. static void test_scalar_sugar_add_mul() {
  27. Tensor<float, 3> A(6, 7, 5);
  28. Tensor<float, 3> B(6, 7, 5);
  29. A.setRandom();
  30. B.setRandom();
  31. const float alpha = 0.43f;
  32. const float beta = 0.21f;
  33. const float gamma = 0.14f;
  34. Tensor<float, 3> R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta);
  35. Tensor<float, 3> S = A * alpha + B * beta + gamma;
  36. Tensor<float, 3> T = gamma + alpha * A + beta * B;
  37. for (int i = 0; i < 6*7*5; ++i) {
  38. VERIFY_IS_APPROX(R(i), S(i));
  39. VERIFY_IS_APPROX(R(i), T(i));
  40. }
  41. }
  42. static void test_scalar_sugar_sub_div() {
  43. Tensor<float, 3> A(6, 7, 5);
  44. Tensor<float, 3> B(6, 7, 5);
  45. A.setRandom();
  46. B.setRandom();
  47. const float alpha = 0.43f;
  48. const float beta = 0.21f;
  49. const float gamma = 0.14f;
  50. const float delta = 0.32f;
  51. Tensor<float, 3> R = A.constant(gamma) - A / A.constant(alpha)
  52. - B.constant(beta) / B - A.constant(delta);
  53. Tensor<float, 3> S = gamma - A / alpha - beta / B - delta;
  54. for (int i = 0; i < 6*7*5; ++i) {
  55. VERIFY_IS_APPROX(R(i), S(i));
  56. }
  57. }
  58. EIGEN_DECLARE_TEST(cxx11_tensor_sugar)
  59. {
  60. CALL_SUBTEST(test_comparison_sugar());
  61. CALL_SUBTEST(test_scalar_sugar_add_mul());
  62. CALL_SUBTEST(test_scalar_sugar_sub_div());
  63. }