CompositeRandomAccessorCommon.h 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  #pragma once

  #include <iterator>
  #include <utility>
  3. namespace at { namespace native {
  4. namespace {
  // operator_brackets_proxy is used in CompositeRandomAccessor
  // as the result of operator[].
  // For some iterators, references returned by operator[]
  // could become invalid; operator_brackets_proxy tries to
  // resolve that by making accessor[n] equivalent to
  // *(accessor + n).
  11. template <typename Accessor>
  12. class operator_brackets_proxy {
  13. using reference = typename std::iterator_traits<Accessor>::reference;
  14. using value_type = typename std::iterator_traits<Accessor>::value_type;
  15. public:
  16. C10_HOST_DEVICE
  17. operator_brackets_proxy(Accessor const& accessor)
  18. : accessor(accessor)
  19. {}
  20. C10_HOST_DEVICE
  21. operator reference() {
  22. return *accessor;
  23. }
  24. C10_HOST_DEVICE
  25. reference operator*() {
  26. return *accessor;
  27. }
  28. C10_HOST_DEVICE
  29. operator_brackets_proxy& operator=(value_type const& val) {
  30. *accessor = val;
  31. return *this;
  32. }
  33. private:
  34. Accessor accessor;
  35. };
  36. }
  37. // references_holder is used as a surrogate for the
  38. // references type from std::iterator_traits in CompositeRandomAccessor.
  39. // It is assumed in CompositeRandomAccessor that
  40. // References = tuple<Types&...>,
  41. // Values = tuple<Types...> by default,
  42. // but they could be anything as long as References could be
  43. // cast to Values.
  // If you plan to use it with STL, for example, you will need to
  // define `swap` and `get` (aka std::get) methods.
  46. template <typename Values, typename References>
  47. class references_holder {
  48. public:
  49. using values = Values;
  50. using references = References;
  51. C10_HOST_DEVICE
  52. references_holder(references refs)
  53. : refs{std::move(refs)}
  54. {}
  55. C10_HOST_DEVICE
  56. operator references() {
  57. return refs;
  58. }
  59. C10_HOST_DEVICE
  60. operator values() {
  61. return refs;
  62. }
  63. C10_HOST_DEVICE
  64. references_holder& operator=(values vals) {
  65. refs = vals;
  66. return *this;
  67. }
  68. C10_HOST_DEVICE
  69. references& data() {
  70. return refs;
  71. }
  72. protected:
  73. references refs;
  74. };
  75. // CompositeRandomAccessor is essentially a simplified version of
  76. // a random access iterator over two random access iterators.
  77. // TupleInfo should contain a variadic type `tuple`, and a method `tie`,
  78. // which constructs a tuple of references from a variadic list of arguments.
  79. template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
  80. class CompositeRandomAccessor {
  81. using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
  82. using key_accessor_value_type =
  83. typename std::iterator_traits<KeyAccessor>::value_type;
  84. using value_accessor_value_type =
  85. typename std::iterator_traits<ValueAccessor>::value_type;
  86. using key_accessor_reference_type =
  87. typename std::iterator_traits<KeyAccessor>::reference;
  88. using value_accessor_reference_type =
  89. typename std::iterator_traits<ValueAccessor>::reference;
  90. using composite_value_type = typename TupleInfo::template tuple<
  91. key_accessor_value_type,
  92. value_accessor_value_type>;
  93. using composite_reference = typename TupleInfo::template tuple<
  94. key_accessor_reference_type,
  95. value_accessor_reference_type>;
  96. public:
  97. using value_type = composite_value_type;
  98. using reference = references_holder<composite_value_type, composite_reference>;
  99. // Note that CompositeRandomAccessor does not hold key and values
  100. // in a specific datastrcture, which means that a pointer to a (key, value)
  101. // is not defined. Hence we just use a pointer type of the KeyAccessor.
  102. using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
  103. using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
  104. using iterator_category = std::random_access_iterator_tag;
  105. C10_HOST_DEVICE
  106. CompositeRandomAccessor() = default;
  107. C10_HOST_DEVICE
  108. CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
  109. : keys(keys), values(values)
  110. {}
  111. // Pointer-like operations {
  112. C10_HOST_DEVICE
  113. reference operator*() const {
  114. return TupleInfo::tie(*keys, *values);
  115. }
  116. // operator->() is supposed to return a pointer type.
  117. // Since CompositeRandomAccessor does not hold pointers to pairs,
  118. // we just return a pointer to a key.
  119. C10_HOST_DEVICE
  120. auto* operator->() const {
  121. return keys.operator->();
  122. }
  123. C10_HOST_DEVICE
  124. reference operator[](difference_type idx) {
  125. return operator_brackets_proxy<self_type>(
  126. CompositeRandomAccessor(keys + idx, values + idx)
  127. );
  128. }
  129. // }
  130. // Prefix/postfix increment/decrement {
  131. C10_HOST_DEVICE
  132. CompositeRandomAccessor& operator++() {
  133. ++keys;
  134. ++values;
  135. return *this;
  136. }
  137. C10_HOST_DEVICE
  138. CompositeRandomAccessor operator++(int) {
  139. CompositeRandomAccessor copy(*this);
  140. ++*this;
  141. return copy;
  142. }
  143. C10_HOST_DEVICE
  144. CompositeRandomAccessor& operator--() {
  145. --keys;
  146. --values;
  147. return *this;
  148. }
  149. C10_HOST_DEVICE
  150. CompositeRandomAccessor operator--(int) {
  151. CompositeRandomAccessor copy(*this);
  152. --*this;
  153. return copy;
  154. }
  155. // }
  156. // Arithmetic operations {
  157. C10_HOST_DEVICE
  158. CompositeRandomAccessor& operator+=(difference_type offset) {
  159. keys += offset;
  160. values += offset;
  161. return *this;
  162. }
  163. C10_HOST_DEVICE
  164. CompositeRandomAccessor operator+(difference_type offset) const {
  165. return CompositeRandomAccessor(keys + offset, values + offset);
  166. }
  167. C10_HOST_DEVICE
  168. friend CompositeRandomAccessor operator+(
  169. difference_type offset,
  170. const CompositeRandomAccessor& accessor
  171. ) {
  172. return accessor + offset;
  173. }
  174. C10_HOST_DEVICE
  175. CompositeRandomAccessor& operator-=(difference_type offset) {
  176. keys -= offset;
  177. values -= offset;
  178. return *this;
  179. }
  180. C10_HOST_DEVICE
  181. CompositeRandomAccessor operator-(difference_type offset) const {
  182. return CompositeRandomAccessor(keys - offset, values - offset);
  183. }
  184. C10_HOST_DEVICE
  185. difference_type operator-(const CompositeRandomAccessor& other) const {
  186. return keys - other.keys;
  187. }
  188. // }
  189. // Comparison operators {
  190. C10_HOST_DEVICE
  191. bool operator==(const CompositeRandomAccessor& other) const {
  192. return keys == other.keys;
  193. }
  194. C10_HOST_DEVICE
  195. bool operator!=(const CompositeRandomAccessor& other) const {
  196. return keys != other.keys;
  197. }
  198. C10_HOST_DEVICE
  199. bool operator<(const CompositeRandomAccessor& other) const {
  200. return keys < other.keys;
  201. }
  202. C10_HOST_DEVICE
  203. bool operator<=(const CompositeRandomAccessor& other) const {
  204. return keys <= other.keys;
  205. }
  206. C10_HOST_DEVICE
  207. bool operator>(const CompositeRandomAccessor& other) const {
  208. return keys > other.keys;
  209. }
  210. C10_HOST_DEVICE
  211. bool operator>=(const CompositeRandomAccessor& other) const {
  212. return keys >= other.keys;
  213. }
  214. // }
  215. protected:
  216. KeyAccessor keys;
  217. ValueAccessor values;
  218. };
  219. }} // namespace at::native