123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- #include <utility>
- #pragma once
- namespace at { namespace native {
namespace detail {

// operator_brackets_proxy is used in CompositeRandomAccessor in place of
// operator[]. For some iterators, references returned by operator[] could
// become invalid, so operator_brackets_proxy makes `accessor[n]` equivalent
// to `*(accessor + n)`: it stores an already-offset accessor and
// dereferences it on demand.
//
// This lives in a named `detail` namespace rather than an anonymous one.
// An anonymous namespace in a header gives every translation unit its own
// distinct operator_brackets_proxy type, while the externally linked
// CompositeRandomAccessor templates that use it are shared across
// translation units. That mismatch violates the One Definition Rule.
template <typename Accessor>
class operator_brackets_proxy {
  using reference = typename std::iterator_traits<Accessor>::reference;
  using value_type = typename std::iterator_traits<Accessor>::value_type;

public:
  // Stores a copy of `accessor`; it is expected to already point at the
  // element requested through operator[].
  C10_HOST_DEVICE
  operator_brackets_proxy(Accessor const& accessor)
    : accessor(accessor)
  {}

  // Read access: implicitly converts to the accessor's reference type.
  C10_HOST_DEVICE
  operator reference() {
    return *accessor;
  }

  // Explicit dereference; same result as the conversion above.
  C10_HOST_DEVICE
  reference operator*() {
    return *accessor;
  }

  // Write access: `proxy = val` assigns through the stored accessor.
  C10_HOST_DEVICE
  operator_brackets_proxy& operator=(value_type const& val) {
    *accessor = val;
    return *this;
  }

private:
  Accessor accessor;
};

} // namespace detail

// Keep the historical unqualified name available in the enclosing namespace
// so existing users of operator_brackets_proxy do not need to change.
using detail::operator_brackets_proxy;
// references_holder is used as a surrogate for the `reference` type from
// std::iterator_traits in CompositeRandomAccessor.
// CompositeRandomAccessor assumes by default that
//   References = tuple<Types&...>,
//   Values     = tuple<Types...>,
// but they could be anything as long as References can be converted to
// Values and Values can be assigned to References.
// If you plan to use it with the STL, for example, you will need to
// define `swap` and `get` (a.k.a. std::get) functions for it.
template <typename Values, typename References>
class references_holder {
public:
  using values = Values;
  using references = References;

  // Deliberately implicit: CompositeRandomAccessor::operator* returns the
  // result of TupleInfo::tie(...) and relies on this conversion.
  C10_HOST_DEVICE
  references_holder(references refs)
    : refs{std::move(refs)}
  {}

  // Yields the held references. `const` so a const holder can also be read;
  // constness is shallow, as with a pointer.
  C10_HOST_DEVICE
  operator references() const {
    return refs;
  }

  // Copies the referenced objects out into a `values` object.
  C10_HOST_DEVICE
  operator values() const {
    return refs;
  }

  // Writes `vals` through the held references. `vals` is taken by value,
  // so it is moved into place instead of being copied a second time.
  C10_HOST_DEVICE
  references_holder& operator=(values vals) {
    refs = std::move(vals);
    return *this;
  }

  // Direct access to the held references.
  C10_HOST_DEVICE
  references& data() {
    return refs;
  }

  C10_HOST_DEVICE
  const references& data() const {
    return refs;
  }

protected:
  references refs;
};
- // CompositeRandomAccessor is essentially a simplified version of
- // a random access iterator over two random access iterators.
- // TupleInfo should contain a variadic type `tuple`, and a method `tie`,
- // which constructs a tuple of references from a variadic list of arguments.
- template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
- class CompositeRandomAccessor {
- using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
- using key_accessor_value_type =
- typename std::iterator_traits<KeyAccessor>::value_type;
- using value_accessor_value_type =
- typename std::iterator_traits<ValueAccessor>::value_type;
- using key_accessor_reference_type =
- typename std::iterator_traits<KeyAccessor>::reference;
- using value_accessor_reference_type =
- typename std::iterator_traits<ValueAccessor>::reference;
- using composite_value_type = typename TupleInfo::template tuple<
- key_accessor_value_type,
- value_accessor_value_type>;
- using composite_reference = typename TupleInfo::template tuple<
- key_accessor_reference_type,
- value_accessor_reference_type>;
- public:
- using value_type = composite_value_type;
- using reference = references_holder<composite_value_type, composite_reference>;
- // Note that CompositeRandomAccessor does not hold key and values
- // in a specific datastrcture, which means that a pointer to a (key, value)
- // is not defined. Hence we just use a pointer type of the KeyAccessor.
- using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
- using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
- using iterator_category = std::random_access_iterator_tag;
- C10_HOST_DEVICE
- CompositeRandomAccessor() = default;
- C10_HOST_DEVICE
- CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
- : keys(keys), values(values)
- {}
- // Pointer-like operations {
- C10_HOST_DEVICE
- reference operator*() const {
- return TupleInfo::tie(*keys, *values);
- }
- // operator->() is supposed to return a pointer type.
- // Since CompositeRandomAccessor does not hold pointers to pairs,
- // we just return a pointer to a key.
- C10_HOST_DEVICE
- auto* operator->() const {
- return keys.operator->();
- }
- C10_HOST_DEVICE
- reference operator[](difference_type idx) {
- return operator_brackets_proxy<self_type>(
- CompositeRandomAccessor(keys + idx, values + idx)
- );
- }
- // }
- // Prefix/postfix increment/decrement {
- C10_HOST_DEVICE
- CompositeRandomAccessor& operator++() {
- ++keys;
- ++values;
- return *this;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor operator++(int) {
- CompositeRandomAccessor copy(*this);
- ++*this;
- return copy;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor& operator--() {
- --keys;
- --values;
- return *this;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor operator--(int) {
- CompositeRandomAccessor copy(*this);
- --*this;
- return copy;
- }
- // }
- // Arithmetic operations {
- C10_HOST_DEVICE
- CompositeRandomAccessor& operator+=(difference_type offset) {
- keys += offset;
- values += offset;
- return *this;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor operator+(difference_type offset) const {
- return CompositeRandomAccessor(keys + offset, values + offset);
- }
- C10_HOST_DEVICE
- friend CompositeRandomAccessor operator+(
- difference_type offset,
- const CompositeRandomAccessor& accessor
- ) {
- return accessor + offset;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor& operator-=(difference_type offset) {
- keys -= offset;
- values -= offset;
- return *this;
- }
- C10_HOST_DEVICE
- CompositeRandomAccessor operator-(difference_type offset) const {
- return CompositeRandomAccessor(keys - offset, values - offset);
- }
- C10_HOST_DEVICE
- difference_type operator-(const CompositeRandomAccessor& other) const {
- return keys - other.keys;
- }
- // }
- // Comparison operators {
- C10_HOST_DEVICE
- bool operator==(const CompositeRandomAccessor& other) const {
- return keys == other.keys;
- }
- C10_HOST_DEVICE
- bool operator!=(const CompositeRandomAccessor& other) const {
- return keys != other.keys;
- }
- C10_HOST_DEVICE
- bool operator<(const CompositeRandomAccessor& other) const {
- return keys < other.keys;
- }
- C10_HOST_DEVICE
- bool operator<=(const CompositeRandomAccessor& other) const {
- return keys <= other.keys;
- }
- C10_HOST_DEVICE
- bool operator>(const CompositeRandomAccessor& other) const {
- return keys > other.keys;
- }
- C10_HOST_DEVICE
- bool operator>=(const CompositeRandomAccessor& other) const {
- return keys >= other.keys;
- }
- // }
- protected:
- KeyAccessor keys;
- ValueAccessor values;
- };
- }} // namespace at::native
|