ArrayRef.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. //===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. // ATen: modified from llvm::ArrayRef.
  10. // removed llvm-specific functionality
  11. // removed some implicit const -> non-const conversions that rely on
  12. // complicated std::enable_if meta-programming
  13. // removed a bunch of slice variants for simplicity...
  14. #pragma once
  15. #include <c10/util/C++17.h>
  16. #include <c10/util/Deprecated.h>
  17. #include <c10/util/Exception.h>
  18. #include <c10/util/SmallVector.h>
  19. #include <array>
  20. #include <iterator>
  21. #include <vector>
  22. namespace c10 {
  23. /// ArrayRef - Represent a constant reference to an array (0 or more elements
  24. /// consecutively in memory), i.e. a start pointer and a length. It allows
  25. /// various APIs to take consecutive elements easily and conveniently.
  26. ///
  27. /// This class does not own the underlying data, it is expected to be used in
  28. /// situations where the data resides in some other buffer, whose lifetime
  29. /// extends past that of the ArrayRef. For this reason, it is not in general
  30. /// safe to store an ArrayRef.
  31. ///
  32. /// This is intended to be trivially copyable, so it should be passed by
  33. /// value.
  34. template <typename T>
  35. class ArrayRef final {
  36. public:
  37. using iterator = const T*;
  38. using const_iterator = const T*;
  39. using size_type = size_t;
  40. using value_type = T;
  41. using reverse_iterator = std::reverse_iterator<iterator>;
  42. private:
  43. /// The start of the array, in an external buffer.
  44. const T* Data;
  45. /// The number of elements.
  46. size_type Length;
  47. void debugCheckNullptrInvariant() {
  48. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  49. Data != nullptr || Length == 0,
  50. "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal");
  51. }
  52. public:
  53. /// @name Constructors
  54. /// @{
  55. /// Construct an empty ArrayRef.
  56. /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
  57. /// Construct an ArrayRef from a single element.
  58. // TODO Make this explicit
  59. constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
  60. /// Construct an ArrayRef from a pointer and length.
  61. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length)
  62. : Data(data), Length(length) {
  63. debugCheckNullptrInvariant();
  64. }
  65. /// Construct an ArrayRef from a range.
  66. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end)
  67. : Data(begin), Length(end - begin) {
  68. debugCheckNullptrInvariant();
  69. }
  70. /// Construct an ArrayRef from a SmallVector. This is templated in order to
  71. /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
  72. /// copy-construct an ArrayRef.
  73. template <typename U>
  74. /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
  75. : Data(Vec.data()), Length(Vec.size()) {
  76. debugCheckNullptrInvariant();
  77. }
  78. template <
  79. typename Container,
  80. typename = std::enable_if_t<std::is_same<
  81. std::remove_const_t<decltype(std::declval<Container>().data())>,
  82. T*>::value>>
  83. /* implicit */ ArrayRef(const Container& container)
  84. : Data(container.data()), Length(container.size()) {
  85. debugCheckNullptrInvariant();
  86. }
  87. /// Construct an ArrayRef from a std::vector.
  88. // The enable_if stuff here makes sure that this isn't used for
  89. // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
  90. // bitfield.
  91. template <typename A>
  92. /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
  93. : Data(Vec.data()), Length(Vec.size()) {
  94. static_assert(
  95. !std::is_same<T, bool>::value,
  96. "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  97. }
  98. /// Construct an ArrayRef from a std::array
  99. template <size_t N>
  100. /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
  101. : Data(Arr.data()), Length(N) {}
  102. /// Construct an ArrayRef from a C array.
  103. template <size_t N>
  104. /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
  105. /// Construct an ArrayRef from a std::initializer_list.
  106. /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
  107. : Data(
  108. std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
  109. : std::begin(Vec)),
  110. Length(Vec.size()) {}
  111. /// @}
  112. /// @name Simple Operations
  113. /// @{
  114. constexpr iterator begin() const {
  115. return Data;
  116. }
  117. constexpr iterator end() const {
  118. return Data + Length;
  119. }
  120. // These are actually the same as iterator, since ArrayRef only
  121. // gives you const iterators.
  122. constexpr const_iterator cbegin() const {
  123. return Data;
  124. }
  125. constexpr const_iterator cend() const {
  126. return Data + Length;
  127. }
  128. constexpr reverse_iterator rbegin() const {
  129. return reverse_iterator(end());
  130. }
  131. constexpr reverse_iterator rend() const {
  132. return reverse_iterator(begin());
  133. }
  134. /// empty - Check if the array is empty.
  135. constexpr bool empty() const {
  136. return Length == 0;
  137. }
  138. constexpr const T* data() const {
  139. return Data;
  140. }
  141. /// size - Get the array size.
  142. constexpr size_t size() const {
  143. return Length;
  144. }
  145. /// front - Get the first element.
  146. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
  147. TORCH_CHECK(
  148. !empty(), "ArrayRef: attempted to access front() of empty list");
  149. return Data[0];
  150. }
  151. /// back - Get the last element.
  152. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const {
  153. TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
  154. return Data[Length - 1];
  155. }
  156. /// equals - Check for element-wise equality.
  157. constexpr bool equals(ArrayRef RHS) const {
  158. return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
  159. }
  160. /// slice(n, m) - Take M elements of the array starting at element N
  161. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
  162. const {
  163. TORCH_CHECK(
  164. N + M <= size(),
  165. "ArrayRef: invalid slice, N = ",
  166. N,
  167. "; M = ",
  168. M,
  169. "; size = ",
  170. size());
  171. return ArrayRef<T>(data() + N, M);
  172. }
  173. /// slice(n) - Chop off the first N elements of the array.
  174. constexpr ArrayRef<T> slice(size_t N) const {
  175. return slice(N, size() - N);
  176. }
  177. /// @}
  178. /// @name Operator Overloads
  179. /// @{
  180. constexpr const T& operator[](size_t Index) const {
  181. return Data[Index];
  182. }
  183. /// Vector compatibility
  184. C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const {
  185. TORCH_CHECK(
  186. Index < Length,
  187. "ArrayRef: invalid index Index = ",
  188. Index,
  189. "; Length = ",
  190. Length);
  191. return Data[Index];
  192. }
  193. /// Disallow accidental assignment from a temporary.
  194. ///
  195. /// The declaration here is extra complicated so that "arrayRef = {}"
  196. /// continues to select the move assignment operator.
  197. template <typename U>
  198. typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
  199. operator=(U&& Temporary) = delete;
  200. /// Disallow accidental assignment from a temporary.
  201. ///
  202. /// The declaration here is extra complicated so that "arrayRef = {}"
  203. /// continues to select the move assignment operator.
  204. template <typename U>
  205. typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
  206. operator=(std::initializer_list<U>) = delete;
  207. /// @}
  208. /// @name Expensive Operations
  209. /// @{
  210. std::vector<T> vec() const {
  211. return std::vector<T>(Data, Data + Length);
  212. }
  213. /// @}
  214. };
  215. template <typename T>
  216. std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
  217. int i = 0;
  218. out << "[";
  219. for (const auto& e : list) {
  220. if (i++ > 0)
  221. out << ", ";
  222. out << e;
  223. }
  224. out << "]";
  225. return out;
  226. }
  227. /// @name ArrayRef Convenience constructors
  228. /// @{
  229. /// Construct an ArrayRef from a single element.
  230. template <typename T>
  231. ArrayRef<T> makeArrayRef(const T& OneElt) {
  232. return OneElt;
  233. }
  234. /// Construct an ArrayRef from a pointer and length.
  235. template <typename T>
  236. ArrayRef<T> makeArrayRef(const T* data, size_t length) {
  237. return ArrayRef<T>(data, length);
  238. }
  239. /// Construct an ArrayRef from a range.
  240. template <typename T>
  241. ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
  242. return ArrayRef<T>(begin, end);
  243. }
  244. /// Construct an ArrayRef from a SmallVector.
  245. template <typename T>
  246. ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
  247. return Vec;
  248. }
  249. /// Construct an ArrayRef from a SmallVector.
  250. template <typename T, unsigned N>
  251. ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
  252. return Vec;
  253. }
  254. /// Construct an ArrayRef from a std::vector.
  255. template <typename T>
  256. ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
  257. return Vec;
  258. }
  259. /// Construct an ArrayRef from a std::array.
  260. template <typename T, std::size_t N>
  261. ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
  262. return Arr;
  263. }
  264. /// Construct an ArrayRef from an ArrayRef (no-op) (const)
  265. template <typename T>
  266. ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
  267. return Vec;
  268. }
  269. /// Construct an ArrayRef from an ArrayRef (no-op)
  270. template <typename T>
  271. ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
  272. return Vec;
  273. }
  274. /// Construct an ArrayRef from a C array.
  275. template <typename T, size_t N>
  276. ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
  277. return ArrayRef<T>(Arr);
  278. }
  279. // WARNING: Template instantiation will NOT be willing to do an implicit
  280. // conversions to get you to an c10::ArrayRef, which is why we need so
  281. // many overloads.
  282. template <typename T>
  283. bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
  284. return a1.equals(a2);
  285. }
  286. template <typename T>
  287. bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
  288. return !a1.equals(a2);
  289. }
  290. template <typename T>
  291. bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
  292. return c10::ArrayRef<T>(a1).equals(a2);
  293. }
  294. template <typename T>
  295. bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
  296. return !c10::ArrayRef<T>(a1).equals(a2);
  297. }
  298. template <typename T>
  299. bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
  300. return a1.equals(c10::ArrayRef<T>(a2));
  301. }
  302. template <typename T>
  303. bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
  304. return !a1.equals(c10::ArrayRef<T>(a2));
  305. }
  306. using IntArrayRef = ArrayRef<int64_t>;
  307. // This alias is deprecated because it doesn't make ownership
  308. // semantics obvious. Use IntArrayRef instead!
  309. C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
  310. } // namespace c10