#pragma once namespace at { namespace native { // (Const)StridedRandomAccessor is a // (const) random access iterator defined over // a strided array. // The traits below are to introduce __restrict__ // modifier on different platforms. template struct DefaultPtrTraits { using PtrType = T*; }; #if (defined(_WIN32) || defined(_WIN64)) #define RESTRICT __restrict #else #define RESTRICT __restrict__ #endif template struct RestrictPtrTraits { using PtrType = T* RESTRICT; }; template < typename T, typename index_t = int64_t, template class PtrTraits = DefaultPtrTraits > class ConstStridedRandomAccessor { public: using difference_type = index_t; using value_type = const T; using pointer = const typename PtrTraits::PtrType; using reference = const value_type&; using iterator_category = std::random_access_iterator_tag; using PtrType = typename PtrTraits::PtrType; using index_type = index_t; // Constructors { C10_HOST_DEVICE ConstStridedRandomAccessor(PtrType ptr, index_t stride) : ptr{ptr}, stride{stride} {} C10_HOST_DEVICE explicit ConstStridedRandomAccessor(PtrType ptr) : ptr{ptr}, stride{static_cast(1)} {} C10_HOST_DEVICE ConstStridedRandomAccessor() : ptr{nullptr}, stride{static_cast(1)} {} // } // Pointer-like operations { C10_HOST_DEVICE reference operator*() const { return *ptr; } C10_HOST_DEVICE const value_type* operator->() const { return reinterpret_cast(ptr); } C10_HOST_DEVICE reference operator[](index_t idx) const { return ptr[idx * stride]; } // } // Prefix/postfix increment/decrement { C10_HOST_DEVICE ConstStridedRandomAccessor& operator++() { ptr += stride; return *this; } C10_HOST_DEVICE ConstStridedRandomAccessor operator++(int) { ConstStridedRandomAccessor copy(*this); ++*this; return copy; } C10_HOST_DEVICE ConstStridedRandomAccessor& operator--() { ptr -= stride; return *this; } C10_HOST_DEVICE ConstStridedRandomAccessor operator--(int) { ConstStridedRandomAccessor copy(*this); --*this; return copy; } // } // Arithmetic operations { C10_HOST_DEVICE ConstStridedRandomAccessor& operator+=(index_t offset) { ptr += offset * stride; return *this; } C10_HOST_DEVICE ConstStridedRandomAccessor operator+(index_t offset) const { return ConstStridedRandomAccessor(ptr + offset * stride, stride); } C10_HOST_DEVICE friend ConstStridedRandomAccessor operator+( index_t offset, const ConstStridedRandomAccessor& accessor ) { return accessor + offset; } C10_HOST_DEVICE ConstStridedRandomAccessor& operator-=(index_t offset) { ptr -= offset * stride; return *this; } C10_HOST_DEVICE ConstStridedRandomAccessor operator-(index_t offset) const { return ConstStridedRandomAccessor(ptr - offset * stride, stride); } // Note that this operator is well-defined when `this` and `other` // represent the same sequences, i.e. when // 1. this.stride == other.stride, // 2. |other - this| / this.stride is an Integer. C10_HOST_DEVICE difference_type operator-(const ConstStridedRandomAccessor& other) const { return (ptr - other.ptr) / stride; } // } // Comparison operators { C10_HOST_DEVICE bool operator==(const ConstStridedRandomAccessor& other) const { return (ptr == other.ptr) && (stride == other.stride); } C10_HOST_DEVICE bool operator!=(const ConstStridedRandomAccessor& other) const { return !(*this == other); } C10_HOST_DEVICE bool operator<(const ConstStridedRandomAccessor& other) const { return ptr < other.ptr; } C10_HOST_DEVICE bool operator<=(const ConstStridedRandomAccessor& other) const { return (*this < other) || (*this == other); } C10_HOST_DEVICE bool operator>(const ConstStridedRandomAccessor& other) const { return !(*this <= other); } C10_HOST_DEVICE bool operator>=(const ConstStridedRandomAccessor& other) const { return !(*this < other); } // } protected: PtrType ptr; index_t stride; }; template < typename T, typename index_t = int64_t, template class PtrTraits = DefaultPtrTraits > class StridedRandomAccessor : public ConstStridedRandomAccessor { public: using difference_type = index_t; using value_type = T; using pointer = typename PtrTraits::PtrType; using reference = value_type&; using BaseType = ConstStridedRandomAccessor; using PtrType = typename PtrTraits::PtrType; // Constructors { C10_HOST_DEVICE StridedRandomAccessor(PtrType ptr, index_t stride) : BaseType(ptr, stride) {} C10_HOST_DEVICE explicit StridedRandomAccessor(PtrType ptr) : BaseType(ptr) {} C10_HOST_DEVICE StridedRandomAccessor() : BaseType() {} // } // Pointer-like operations { C10_HOST_DEVICE reference operator*() const { return *this->ptr; } C10_HOST_DEVICE value_type* operator->() const { return reinterpret_cast(this->ptr); } C10_HOST_DEVICE reference operator[](index_t idx) const { return this->ptr[idx * this->stride]; } // } // Prefix/postfix increment/decrement { C10_HOST_DEVICE StridedRandomAccessor& operator++() { this->ptr += this->stride; return *this; } C10_HOST_DEVICE StridedRandomAccessor operator++(int) { StridedRandomAccessor copy(*this); ++*this; return copy; } C10_HOST_DEVICE StridedRandomAccessor& operator--() { this->ptr -= this->stride; return *this; } C10_HOST_DEVICE StridedRandomAccessor operator--(int) { StridedRandomAccessor copy(*this); --*this; return copy; } // } // Arithmetic operations { C10_HOST_DEVICE StridedRandomAccessor& operator+=(index_t offset) { this->ptr += offset * this->stride; return *this; } C10_HOST_DEVICE StridedRandomAccessor operator+(index_t offset) const { return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride); } C10_HOST_DEVICE friend StridedRandomAccessor operator+( index_t offset, const StridedRandomAccessor& accessor ) { return accessor + offset; } C10_HOST_DEVICE StridedRandomAccessor& operator-=(index_t offset) { this->ptr -= offset * this->stride; return *this; } C10_HOST_DEVICE StridedRandomAccessor operator-(index_t offset) const { return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride); } // Note that here we call BaseType::operator- version C10_HOST_DEVICE difference_type operator-(const BaseType& other) const { return (static_cast(*this) - other); } // } }; }} // namespace at::native