StridedRandomAccessor.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. #pragma once
  2. namespace at { namespace native {
  // (Const)StridedRandomAccessor is a (const) random-access iterator
  // defined over a strided array.
  //
  // The pointer traits below let the accessors optionally use the
  // `__restrict__` modifier, whose spelling differs across platforms.
  /// Pointer policy for the strided accessors that adds no qualifiers:
  /// the stored pointer type is a plain `T*`.
  template <class T>
  struct DefaultPtrTraits {
    using PtrType = T*;  // unqualified pointer to the element type
  };
  // `restrict` is not standard C++. Windows toolchains take the `__restrict`
  // spelling; every other platform uses the `__restrict__` spelling.
  #if defined(_WIN32) || defined(_WIN64)
  #  define RESTRICT __restrict
  #else
  #  define RESTRICT __restrict__
  #endif

  /// Pointer policy for the strided accessors whose stored pointer is a
  /// restrict-qualified `T*`. This promises the compiler that the pointee
  /// is not reached through any other pointer; upholding that is the
  /// caller's responsibility.
  template <class T>
  struct RestrictPtrTraits {
    using PtrType = T* RESTRICT;
  };
  21. template <
  22. typename T,
  23. typename index_t = int64_t,
  24. template <typename U> class PtrTraits = DefaultPtrTraits
  25. >
  26. class ConstStridedRandomAccessor {
  27. public:
  28. using difference_type = index_t;
  29. using value_type = const T;
  30. using pointer = const typename PtrTraits<T>::PtrType;
  31. using reference = const value_type&;
  32. using iterator_category = std::random_access_iterator_tag;
  33. using PtrType = typename PtrTraits<T>::PtrType;
  34. using index_type = index_t;
  35. // Constructors {
  36. C10_HOST_DEVICE
  37. ConstStridedRandomAccessor(PtrType ptr, index_t stride)
  38. : ptr{ptr}, stride{stride}
  39. {}
  40. C10_HOST_DEVICE
  41. explicit ConstStridedRandomAccessor(PtrType ptr)
  42. : ptr{ptr}, stride{static_cast<index_t>(1)}
  43. {}
  44. C10_HOST_DEVICE
  45. ConstStridedRandomAccessor()
  46. : ptr{nullptr}, stride{static_cast<index_t>(1)}
  47. {}
  48. // }
  49. // Pointer-like operations {
  50. C10_HOST_DEVICE
  51. reference operator*() const {
  52. return *ptr;
  53. }
  54. C10_HOST_DEVICE
  55. const value_type* operator->() const {
  56. return reinterpret_cast<const value_type*>(ptr);
  57. }
  58. C10_HOST_DEVICE
  59. reference operator[](index_t idx) const {
  60. return ptr[idx * stride];
  61. }
  62. // }
  63. // Prefix/postfix increment/decrement {
  64. C10_HOST_DEVICE
  65. ConstStridedRandomAccessor& operator++() {
  66. ptr += stride;
  67. return *this;
  68. }
  69. C10_HOST_DEVICE
  70. ConstStridedRandomAccessor operator++(int) {
  71. ConstStridedRandomAccessor copy(*this);
  72. ++*this;
  73. return copy;
  74. }
  75. C10_HOST_DEVICE
  76. ConstStridedRandomAccessor& operator--() {
  77. ptr -= stride;
  78. return *this;
  79. }
  80. C10_HOST_DEVICE
  81. ConstStridedRandomAccessor operator--(int) {
  82. ConstStridedRandomAccessor copy(*this);
  83. --*this;
  84. return copy;
  85. }
  86. // }
  87. // Arithmetic operations {
  88. C10_HOST_DEVICE
  89. ConstStridedRandomAccessor& operator+=(index_t offset) {
  90. ptr += offset * stride;
  91. return *this;
  92. }
  93. C10_HOST_DEVICE
  94. ConstStridedRandomAccessor operator+(index_t offset) const {
  95. return ConstStridedRandomAccessor(ptr + offset * stride, stride);
  96. }
  97. C10_HOST_DEVICE
  98. friend ConstStridedRandomAccessor operator+(
  99. index_t offset,
  100. const ConstStridedRandomAccessor& accessor
  101. ) {
  102. return accessor + offset;
  103. }
  104. C10_HOST_DEVICE
  105. ConstStridedRandomAccessor& operator-=(index_t offset) {
  106. ptr -= offset * stride;
  107. return *this;
  108. }
  109. C10_HOST_DEVICE
  110. ConstStridedRandomAccessor operator-(index_t offset) const {
  111. return ConstStridedRandomAccessor(ptr - offset * stride, stride);
  112. }
  113. // Note that this operator is well-defined when `this` and `other`
  114. // represent the same sequences, i.e. when
  115. // 1. this.stride == other.stride,
  116. // 2. |other - this| / this.stride is an Integer.
  117. C10_HOST_DEVICE
  118. difference_type operator-(const ConstStridedRandomAccessor& other) const {
  119. return (ptr - other.ptr) / stride;
  120. }
  121. // }
  122. // Comparison operators {
  123. C10_HOST_DEVICE
  124. bool operator==(const ConstStridedRandomAccessor& other) const {
  125. return (ptr == other.ptr) && (stride == other.stride);
  126. }
  127. C10_HOST_DEVICE
  128. bool operator!=(const ConstStridedRandomAccessor& other) const {
  129. return !(*this == other);
  130. }
  131. C10_HOST_DEVICE
  132. bool operator<(const ConstStridedRandomAccessor& other) const {
  133. return ptr < other.ptr;
  134. }
  135. C10_HOST_DEVICE
  136. bool operator<=(const ConstStridedRandomAccessor& other) const {
  137. return (*this < other) || (*this == other);
  138. }
  139. C10_HOST_DEVICE
  140. bool operator>(const ConstStridedRandomAccessor& other) const {
  141. return !(*this <= other);
  142. }
  143. C10_HOST_DEVICE
  144. bool operator>=(const ConstStridedRandomAccessor& other) const {
  145. return !(*this < other);
  146. }
  147. // }
  148. protected:
  149. PtrType ptr;
  150. index_t stride;
  151. };
  152. template <
  153. typename T,
  154. typename index_t = int64_t,
  155. template <typename U> class PtrTraits = DefaultPtrTraits
  156. >
  157. class StridedRandomAccessor
  158. : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
  159. public:
  160. using difference_type = index_t;
  161. using value_type = T;
  162. using pointer = typename PtrTraits<T>::PtrType;
  163. using reference = value_type&;
  164. using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
  165. using PtrType = typename PtrTraits<T>::PtrType;
  166. // Constructors {
  167. C10_HOST_DEVICE
  168. StridedRandomAccessor(PtrType ptr, index_t stride)
  169. : BaseType(ptr, stride)
  170. {}
  171. C10_HOST_DEVICE
  172. explicit StridedRandomAccessor(PtrType ptr)
  173. : BaseType(ptr)
  174. {}
  175. C10_HOST_DEVICE
  176. StridedRandomAccessor()
  177. : BaseType()
  178. {}
  179. // }
  180. // Pointer-like operations {
  181. C10_HOST_DEVICE
  182. reference operator*() const {
  183. return *this->ptr;
  184. }
  185. C10_HOST_DEVICE
  186. value_type* operator->() const {
  187. return reinterpret_cast<value_type*>(this->ptr);
  188. }
  189. C10_HOST_DEVICE
  190. reference operator[](index_t idx) const {
  191. return this->ptr[idx * this->stride];
  192. }
  193. // }
  194. // Prefix/postfix increment/decrement {
  195. C10_HOST_DEVICE
  196. StridedRandomAccessor& operator++() {
  197. this->ptr += this->stride;
  198. return *this;
  199. }
  200. C10_HOST_DEVICE
  201. StridedRandomAccessor operator++(int) {
  202. StridedRandomAccessor copy(*this);
  203. ++*this;
  204. return copy;
  205. }
  206. C10_HOST_DEVICE
  207. StridedRandomAccessor& operator--() {
  208. this->ptr -= this->stride;
  209. return *this;
  210. }
  211. C10_HOST_DEVICE
  212. StridedRandomAccessor operator--(int) {
  213. StridedRandomAccessor copy(*this);
  214. --*this;
  215. return copy;
  216. }
  217. // }
  218. // Arithmetic operations {
  219. C10_HOST_DEVICE
  220. StridedRandomAccessor& operator+=(index_t offset) {
  221. this->ptr += offset * this->stride;
  222. return *this;
  223. }
  224. C10_HOST_DEVICE
  225. StridedRandomAccessor operator+(index_t offset) const {
  226. return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
  227. }
  228. C10_HOST_DEVICE
  229. friend StridedRandomAccessor operator+(
  230. index_t offset,
  231. const StridedRandomAccessor& accessor
  232. ) {
  233. return accessor + offset;
  234. }
  235. C10_HOST_DEVICE
  236. StridedRandomAccessor& operator-=(index_t offset) {
  237. this->ptr -= offset * this->stride;
  238. return *this;
  239. }
  240. C10_HOST_DEVICE
  241. StridedRandomAccessor operator-(index_t offset) const {
  242. return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
  243. }
  244. // Note that here we call BaseType::operator- version
  245. C10_HOST_DEVICE
  246. difference_type operator-(const BaseType& other) const {
  247. return (static_cast<const BaseType&>(*this) - other);
  248. }
  249. // }
  250. };
  251. }} // namespace at::native