//
//  Copyright (c) 2018-2019, Cem Bassoy, cem.bassoy@gmail.com
//
//  Distributed under the Boost Software License, Version 1.0. (See
//  accompanying file LICENSE_1_0.txt or copy at
//  http://www.boost.org/LICENSE_1_0.txt)
//
//  The authors gratefully acknowledge the support of
//  Fraunhofer IOSB, Ettlingen, Germany
//
- #ifndef BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
- #define BOOST_NUMERIC_UBLAS_TENSOR_EXTENTS_HPP
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>
- namespace boost {
- namespace numeric {
- namespace ublas {
- /** @brief Template class for storing tensor extents with runtime variable size.
- *
- * Proxy template class of std::vector<int_type>.
- *
- */
- template<class int_type>
- class basic_extents
- {
- static_assert( std::numeric_limits<typename std::vector<int_type>::value_type>::is_integer, "Static error in basic_layout: type must be of type integer.");
- static_assert(!std::numeric_limits<typename std::vector<int_type>::value_type>::is_signed, "Static error in basic_layout: type must be of type unsigned integer.");
- public:
- using base_type = std::vector<int_type>;
- using value_type = typename base_type::value_type;
- using const_reference = typename base_type::const_reference;
- using reference = typename base_type::reference;
- using size_type = typename base_type::size_type;
- using const_pointer = typename base_type::const_pointer;
- using const_iterator = typename base_type::const_iterator;
- /** @brief Default constructs basic_extents
- *
- * @code auto ex = basic_extents<unsigned>{};
- */
- constexpr explicit basic_extents()
- : _base{}
- {
- }
- /** @brief Copy constructs basic_extents from a one-dimensional container
- *
- * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
- *
- * @note checks if size > 1 and all elements > 0
- *
- * @param b one-dimensional std::vector<int_type> container
- */
- explicit basic_extents(base_type const& b)
- : _base(b)
- {
- if (!this->valid()){
- throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
- }
- }
- /** @brief Move constructs basic_extents from a one-dimensional container
- *
- * @code auto ex = basic_extents<unsigned>( std::vector<unsigned>(3u,3u) );
- *
- * @note checks if size > 1 and all elements > 0
- *
- * @param b one-dimensional container of type std::vector<int_type>
- */
- explicit basic_extents(base_type && b)
- : _base(std::move(b))
- {
- if (!this->valid()){
- throw std::length_error("Error in basic_extents::basic_extents() : shape tuple is not a valid permutation: has zero elements.");
- }
- }
- /** @brief Constructs basic_extents from an initializer list
- *
- * @code auto ex = basic_extents<unsigned>{3,2,4};
- *
- * @note checks if size > 1 and all elements > 0
- *
- * @param l one-dimensional list of type std::initializer<int_type>
- */
- basic_extents(std::initializer_list<value_type> l)
- : basic_extents( base_type(std::move(l)) )
- {
- }
- /** @brief Constructs basic_extents from a range specified by two iterators
- *
- * @code auto ex = basic_extents<unsigned>(a.begin(), a.end());
- *
- * @note checks if size > 1 and all elements > 0
- *
- * @param first iterator pointing to the first element
- * @param last iterator pointing to the next position after the last element
- */
- basic_extents(const_iterator first, const_iterator last)
- : basic_extents ( base_type( first,last ) )
- {
- }
- /** @brief Copy constructs basic_extents */
- basic_extents(basic_extents const& l )
- : _base(l._base)
- {
- }
- /** @brief Move constructs basic_extents */
- basic_extents(basic_extents && l ) noexcept
- : _base(std::move(l._base))
- {
- }
- ~basic_extents() = default;
- basic_extents& operator=(basic_extents other) noexcept
- {
- swap (*this, other);
- return *this;
- }
- friend void swap(basic_extents& lhs, basic_extents& rhs) {
- std::swap(lhs._base , rhs._base );
- }
- /** @brief Returns true if this has a scalar shape
- *
- * @returns true if (1,1,[1,...,1])
- */
- bool is_scalar() const
- {
- return
- _base.size() != 0 &&
- std::all_of(_base.begin(), _base.end(),
- [](const_reference a){ return a == 1;});
- }
- /** @brief Returns true if this has a vector shape
- *
- * @returns true if (1,n,[1,...,1]) or (n,1,[1,...,1]) with n > 1
- */
- bool is_vector() const
- {
- if(_base.size() == 0){
- return false;
- }
- if(_base.size() == 1){
- return _base.at(0) > 1;
- }
- auto greater_one = [](const_reference a){ return a > 1;};
- auto equal_one = [](const_reference a){ return a == 1;};
- return
- std::any_of(_base.begin(), _base.begin()+2, greater_one) &&
- std::any_of(_base.begin(), _base.begin()+2, equal_one ) &&
- std::all_of(_base.begin()+2, _base.end(), equal_one);
- }
- /** @brief Returns true if this has a matrix shape
- *
- * @returns true if (m,n,[1,...,1]) with m > 1 and n > 1
- */
- bool is_matrix() const
- {
- if(_base.size() < 2){
- return false;
- }
- auto greater_one = [](const_reference a){ return a > 1;};
- auto equal_one = [](const_reference a){ return a == 1;};
- return
- std::all_of(_base.begin(), _base.begin()+2, greater_one) &&
- std::all_of(_base.begin()+2, _base.end(), equal_one );
- }
- /** @brief Returns true if this is has a tensor shape
- *
- * @returns true if !empty() && !is_scalar() && !is_vector() && !is_matrix()
- */
- bool is_tensor() const
- {
- if(_base.size() < 3){
- return false;
- }
- auto greater_one = [](const_reference a){ return a > 1;};
- return std::any_of(_base.begin()+2, _base.end(), greater_one);
- }
- const_pointer data() const
- {
- return this->_base.data();
- }
- const_reference operator[] (size_type p) const
- {
- return this->_base[p];
- }
- const_reference at (size_type p) const
- {
- return this->_base.at(p);
- }
- reference operator[] (size_type p)
- {
- return this->_base[p];
- }
- reference at (size_type p)
- {
- return this->_base.at(p);
- }
- bool empty() const
- {
- return this->_base.empty();
- }
- size_type size() const
- {
- return this->_base.size();
- }
- /** @brief Returns true if size > 1 and all elements > 0 */
- bool valid() const
- {
- return
- this->size() > 1 &&
- std::none_of(_base.begin(), _base.end(),
- [](const_reference a){ return a == value_type(0); });
- }
- /** @brief Returns the number of elements a tensor holds with this */
- size_type product() const
- {
- if(_base.empty()){
- return 0;
- }
- return std::accumulate(_base.begin(), _base.end(), 1ul, std::multiplies<>());
- }
- /** @brief Eliminates singleton dimensions when size > 2
- *
- * squeeze { 1,1} -> { 1,1}
- * squeeze { 2,1} -> { 2,1}
- * squeeze { 1,2} -> { 1,2}
- *
- * squeeze {1,2,3} -> { 2,3}
- * squeeze {2,1,3} -> { 2,3}
- * squeeze {1,3,1} -> { 3,1}
- *
- */
- basic_extents squeeze() const
- {
- if(this->size() <= 2){
- return *this;
- }
- auto new_extent = basic_extents{};
- auto insert_iter = std::back_insert_iterator<typename basic_extents::base_type>(new_extent._base);
- std::remove_copy(this->_base.begin(), this->_base.end(), insert_iter ,value_type{1});
- return new_extent;
- }
- void clear()
- {
- this->_base.clear();
- }
- bool operator == (basic_extents const& b) const
- {
- return _base == b._base;
- }
- bool operator != (basic_extents const& b) const
- {
- return !( _base == b._base );
- }
- const_iterator
- begin() const
- {
- return _base.begin();
- }
- const_iterator
- end() const
- {
- return _base.end();
- }
- base_type const& base() const { return _base; }
- private:
- base_type _base;
- };
- using shape = basic_extents<std::size_t>;
- } // namespace ublas
- } // namespace numeric
- } // namespace boost
- #endif
|