nullary_indexing.cpp 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #include <Eigen/Core>
  2. #include <iostream>
  3. using namespace Eigen;
  4. // [functor]
  5. template<class ArgType, class RowIndexType, class ColIndexType>
  6. class indexing_functor {
  7. const ArgType &m_arg;
  8. const RowIndexType &m_rowIndices;
  9. const ColIndexType &m_colIndices;
  10. public:
  11. typedef Matrix<typename ArgType::Scalar,
  12. RowIndexType::SizeAtCompileTime,
  13. ColIndexType::SizeAtCompileTime,
  14. ArgType::Flags&RowMajorBit?RowMajor:ColMajor,
  15. RowIndexType::MaxSizeAtCompileTime,
  16. ColIndexType::MaxSizeAtCompileTime> MatrixType;
  17. indexing_functor(const ArgType& arg, const RowIndexType& row_indices, const ColIndexType& col_indices)
  18. : m_arg(arg), m_rowIndices(row_indices), m_colIndices(col_indices)
  19. {}
  20. const typename ArgType::Scalar& operator() (Index row, Index col) const {
  21. return m_arg(m_rowIndices[row], m_colIndices[col]);
  22. }
  23. };
  24. // [functor]
  25. // [function]
  26. template <class ArgType, class RowIndexType, class ColIndexType>
  27. CwiseNullaryOp<indexing_functor<ArgType,RowIndexType,ColIndexType>, typename indexing_functor<ArgType,RowIndexType,ColIndexType>::MatrixType>
  28. mat_indexing(const Eigen::MatrixBase<ArgType>& arg, const RowIndexType& row_indices, const ColIndexType& col_indices)
  29. {
  30. typedef indexing_functor<ArgType,RowIndexType,ColIndexType> Func;
  31. typedef typename Func::MatrixType MatrixType;
  32. return MatrixType::NullaryExpr(row_indices.size(), col_indices.size(), Func(arg.derived(), row_indices, col_indices));
  33. }
  34. // [function]
  35. int main()
  36. {
  37. std::cout << "[main1]\n";
  38. Eigen::MatrixXi A = Eigen::MatrixXi::Random(4,4);
  39. Array3i ri(1,2,1);
  40. ArrayXi ci(6); ci << 3,2,1,0,0,2;
  41. Eigen::MatrixXi B = mat_indexing(A, ri, ci);
  42. std::cout << "A =" << std::endl;
  43. std::cout << A << std::endl << std::endl;
  44. std::cout << "A([" << ri.transpose() << "], [" << ci.transpose() << "]) =" << std::endl;
  45. std::cout << B << std::endl;
  46. std::cout << "[main1]\n";
  47. std::cout << "[main2]\n";
  48. B = mat_indexing(A, ri+1, ci);
  49. std::cout << "A(ri+1,ci) =" << std::endl;
  50. std::cout << B << std::endl << std::endl;
  51. #if EIGEN_COMP_CXXVER >= 11
  52. B = mat_indexing(A, ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3));
  53. std::cout << "A(ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)) =" << std::endl;
  54. std::cout << B << std::endl << std::endl;
  55. #endif
  56. std::cout << "[main2]\n";
  57. }