123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- #pragma once
- namespace at { namespace native {
- // (Const)StridedRandomAccessor is a
- // (const) random access iterator defined over
- // a strided array.
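- // The traits below select the pointer type stored by the accessors,
- // optionally adding the compiler-specific restrict qualifier
- // (`__restrict` when _WIN32/_WIN64 is defined, `__restrict__` otherwise).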
// Default pointer traits: PtrType is a plain, unqualified pointer to
// the element type.
template <class Element>
struct DefaultPtrTraits {
  using PtrType = Element*;
};
// RESTRICT expands to the compiler-specific spelling of the restrict
// pointer qualifier: `__restrict` when _WIN32 or _WIN64 is defined,
// `__restrict__` otherwise.
#if !defined(_WIN32) && !defined(_WIN64)
#define RESTRICT __restrict__
#else
#define RESTRICT __restrict
#endif
- template <typename T>
- struct RestrictPtrTraits {
- using PtrType = T* RESTRICT;
- };
- template <
- typename T,
- typename index_t = int64_t,
- template <typename U> class PtrTraits = DefaultPtrTraits
- >
- class ConstStridedRandomAccessor {
- public:
- using difference_type = index_t;
- using value_type = const T;
- using pointer = const typename PtrTraits<T>::PtrType;
- using reference = const value_type&;
- using iterator_category = std::random_access_iterator_tag;
- using PtrType = typename PtrTraits<T>::PtrType;
- using index_type = index_t;
- // Constructors {
- C10_HOST_DEVICE
- ConstStridedRandomAccessor(PtrType ptr, index_t stride)
- : ptr{ptr}, stride{stride}
- {}
- C10_HOST_DEVICE
- explicit ConstStridedRandomAccessor(PtrType ptr)
- : ptr{ptr}, stride{static_cast<index_t>(1)}
- {}
- C10_HOST_DEVICE
- ConstStridedRandomAccessor()
- : ptr{nullptr}, stride{static_cast<index_t>(1)}
- {}
- // }
- // Pointer-like operations {
- C10_HOST_DEVICE
- reference operator*() const {
- return *ptr;
- }
- C10_HOST_DEVICE
- const value_type* operator->() const {
- return reinterpret_cast<const value_type*>(ptr);
- }
- C10_HOST_DEVICE
- reference operator[](index_t idx) const {
- return ptr[idx * stride];
- }
- // }
- // Prefix/postfix increment/decrement {
- C10_HOST_DEVICE
- ConstStridedRandomAccessor& operator++() {
- ptr += stride;
- return *this;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor operator++(int) {
- ConstStridedRandomAccessor copy(*this);
- ++*this;
- return copy;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor& operator--() {
- ptr -= stride;
- return *this;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor operator--(int) {
- ConstStridedRandomAccessor copy(*this);
- --*this;
- return copy;
- }
- // }
- // Arithmetic operations {
- C10_HOST_DEVICE
- ConstStridedRandomAccessor& operator+=(index_t offset) {
- ptr += offset * stride;
- return *this;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor operator+(index_t offset) const {
- return ConstStridedRandomAccessor(ptr + offset * stride, stride);
- }
- C10_HOST_DEVICE
- friend ConstStridedRandomAccessor operator+(
- index_t offset,
- const ConstStridedRandomAccessor& accessor
- ) {
- return accessor + offset;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor& operator-=(index_t offset) {
- ptr -= offset * stride;
- return *this;
- }
- C10_HOST_DEVICE
- ConstStridedRandomAccessor operator-(index_t offset) const {
- return ConstStridedRandomAccessor(ptr - offset * stride, stride);
- }
- // Note that this operator is well-defined when `this` and `other`
- // represent the same sequences, i.e. when
- // 1. this.stride == other.stride,
- // 2. |other - this| / this.stride is an Integer.
- C10_HOST_DEVICE
- difference_type operator-(const ConstStridedRandomAccessor& other) const {
- return (ptr - other.ptr) / stride;
- }
- // }
- // Comparison operators {
- C10_HOST_DEVICE
- bool operator==(const ConstStridedRandomAccessor& other) const {
- return (ptr == other.ptr) && (stride == other.stride);
- }
- C10_HOST_DEVICE
- bool operator!=(const ConstStridedRandomAccessor& other) const {
- return !(*this == other);
- }
- C10_HOST_DEVICE
- bool operator<(const ConstStridedRandomAccessor& other) const {
- return ptr < other.ptr;
- }
- C10_HOST_DEVICE
- bool operator<=(const ConstStridedRandomAccessor& other) const {
- return (*this < other) || (*this == other);
- }
- C10_HOST_DEVICE
- bool operator>(const ConstStridedRandomAccessor& other) const {
- return !(*this <= other);
- }
- C10_HOST_DEVICE
- bool operator>=(const ConstStridedRandomAccessor& other) const {
- return !(*this < other);
- }
- // }
- protected:
- PtrType ptr;
- index_t stride;
- };
- template <
- typename T,
- typename index_t = int64_t,
- template <typename U> class PtrTraits = DefaultPtrTraits
- >
- class StridedRandomAccessor
- : public ConstStridedRandomAccessor<T, index_t, PtrTraits> {
- public:
- using difference_type = index_t;
- using value_type = T;
- using pointer = typename PtrTraits<T>::PtrType;
- using reference = value_type&;
- using BaseType = ConstStridedRandomAccessor<T, index_t, PtrTraits>;
- using PtrType = typename PtrTraits<T>::PtrType;
- // Constructors {
- C10_HOST_DEVICE
- StridedRandomAccessor(PtrType ptr, index_t stride)
- : BaseType(ptr, stride)
- {}
- C10_HOST_DEVICE
- explicit StridedRandomAccessor(PtrType ptr)
- : BaseType(ptr)
- {}
- C10_HOST_DEVICE
- StridedRandomAccessor()
- : BaseType()
- {}
- // }
- // Pointer-like operations {
- C10_HOST_DEVICE
- reference operator*() const {
- return *this->ptr;
- }
- C10_HOST_DEVICE
- value_type* operator->() const {
- return reinterpret_cast<value_type*>(this->ptr);
- }
- C10_HOST_DEVICE
- reference operator[](index_t idx) const {
- return this->ptr[idx * this->stride];
- }
- // }
- // Prefix/postfix increment/decrement {
- C10_HOST_DEVICE
- StridedRandomAccessor& operator++() {
- this->ptr += this->stride;
- return *this;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor operator++(int) {
- StridedRandomAccessor copy(*this);
- ++*this;
- return copy;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor& operator--() {
- this->ptr -= this->stride;
- return *this;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor operator--(int) {
- StridedRandomAccessor copy(*this);
- --*this;
- return copy;
- }
- // }
- // Arithmetic operations {
- C10_HOST_DEVICE
- StridedRandomAccessor& operator+=(index_t offset) {
- this->ptr += offset * this->stride;
- return *this;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor operator+(index_t offset) const {
- return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride);
- }
- C10_HOST_DEVICE
- friend StridedRandomAccessor operator+(
- index_t offset,
- const StridedRandomAccessor& accessor
- ) {
- return accessor + offset;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor& operator-=(index_t offset) {
- this->ptr -= offset * this->stride;
- return *this;
- }
- C10_HOST_DEVICE
- StridedRandomAccessor operator-(index_t offset) const {
- return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride);
- }
- // Note that here we call BaseType::operator- version
- C10_HOST_DEVICE
- difference_type operator-(const BaseType& other) const {
- return (static_cast<const BaseType&>(*this) - other);
- }
- // }
- };
- }} // namespace at::native
|