// OptionalArrayRef.h
// This file defines OptionalArrayRef<T>, a class that has almost the same
// exact functionality as c10::optional<ArrayRef<T>>, except that its
// converting constructor fixes a dangling pointer issue.
//
// The implicit converting constructor of both c10::optional<ArrayRef<T>> and
// std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store
// a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
// a c10::optional<ArrayRef<T>> and fixing the constructor implementation.
//
// See https://github.com/pytorch/pytorch/issues/63645 for more on this.
  11. #pragma once
  12. #include <c10/util/ArrayRef.h>
  13. #include <c10/util/Optional.h>
  14. namespace c10 {
  15. template <typename T>
  16. class OptionalArrayRef final {
  17. public:
  18. // Constructors
  19. constexpr OptionalArrayRef() noexcept = default;
  20. constexpr OptionalArrayRef(nullopt_t) noexcept {}
  21. OptionalArrayRef(const OptionalArrayRef& other) = default;
  22. OptionalArrayRef(OptionalArrayRef&& other) = default;
  23. constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept
  24. : wrapped_opt_array_ref(other) {}
  25. constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept
  26. : wrapped_opt_array_ref(other) {}
  27. constexpr OptionalArrayRef(const T& value) noexcept
  28. : wrapped_opt_array_ref(value) {}
  29. template <
  30. typename U = ArrayRef<T>,
  31. std::enable_if_t<
  32. !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
  33. !std::is_same<std::decay_t<U>, in_place_t>::value &&
  34. std::is_constructible<ArrayRef<T>, U&&>::value &&
  35. std::is_convertible<U&&, ArrayRef<T>>::value &&
  36. !std::is_convertible<U&&, T>::value,
  37. bool> = false>
  38. constexpr OptionalArrayRef(U&& value) noexcept(
  39. std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
  40. : wrapped_opt_array_ref(value) {}
  41. template <
  42. typename U = ArrayRef<T>,
  43. std::enable_if_t<
  44. !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
  45. !std::is_same<std::decay_t<U>, in_place_t>::value &&
  46. std::is_constructible<ArrayRef<T>, U&&>::value &&
  47. !std::is_convertible<U&&, ArrayRef<T>>::value,
  48. bool> = false>
  49. constexpr explicit OptionalArrayRef(U&& value) noexcept(
  50. std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
  51. : wrapped_opt_array_ref(value) {}
  52. template <typename... Args>
  53. constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept
  54. : wrapped_opt_array_ref(ip, args...) {}
  55. template <typename U, typename... Args>
  56. constexpr explicit OptionalArrayRef(
  57. in_place_t ip,
  58. std::initializer_list<U> il,
  59. Args&&... args)
  60. : wrapped_opt_array_ref(ip, il, args...) {}
  61. constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
  62. : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
  63. // Destructor
  64. ~OptionalArrayRef() = default;
  65. // Assignment
  66. constexpr OptionalArrayRef& operator=(nullopt_t) noexcept {
  67. wrapped_opt_array_ref = c10::nullopt;
  68. return *this;
  69. }
  70. OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
  71. OptionalArrayRef& operator=(OptionalArrayRef&& other) = default;
  72. constexpr OptionalArrayRef& operator=(
  73. const optional<ArrayRef<T>>& other) noexcept {
  74. wrapped_opt_array_ref = other;
  75. return *this;
  76. }
  77. constexpr OptionalArrayRef& operator=(
  78. optional<ArrayRef<T>>&& other) noexcept {
  79. wrapped_opt_array_ref = other;
  80. return *this;
  81. }
  82. template <typename U = ArrayRef<T>>
  83. constexpr std::enable_if_t<
  84. !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
  85. std::is_constructible<ArrayRef<T>, U&&>::value &&
  86. std::is_assignable<ArrayRef<T>&, U&&>::value,
  87. OptionalArrayRef&>
  88. operator=(U&& value) noexcept(
  89. std::is_nothrow_constructible<ArrayRef<T>, U&&>::value&&
  90. std::is_nothrow_assignable<ArrayRef<T>&, U&&>::value) {
  91. wrapped_opt_array_ref = value;
  92. return *this;
  93. }
  94. // Observers
  95. constexpr ArrayRef<T>* operator->() noexcept {
  96. return &wrapped_opt_array_ref.value();
  97. }
  98. constexpr const ArrayRef<T>* operator->() const noexcept {
  99. return &wrapped_opt_array_ref.value();
  100. }
  101. constexpr ArrayRef<T>& operator*() & noexcept {
  102. return wrapped_opt_array_ref.value();
  103. }
  104. constexpr const ArrayRef<T>& operator*() const& noexcept {
  105. return wrapped_opt_array_ref.value();
  106. }
  107. constexpr ArrayRef<T>&& operator*() && noexcept {
  108. return std::move(wrapped_opt_array_ref.value());
  109. }
  110. constexpr const ArrayRef<T>&& operator*() const&& noexcept {
  111. return std::move(wrapped_opt_array_ref.value());
  112. }
  113. constexpr explicit operator bool() const noexcept {
  114. return wrapped_opt_array_ref.has_value();
  115. }
  116. constexpr bool has_value() const noexcept {
  117. return wrapped_opt_array_ref.has_value();
  118. }
  119. constexpr ArrayRef<T>& value() & {
  120. return wrapped_opt_array_ref.value();
  121. }
  122. constexpr const ArrayRef<T>& value() const& {
  123. return wrapped_opt_array_ref.value();
  124. }
  125. constexpr ArrayRef<T>&& value() && {
  126. return std::move(wrapped_opt_array_ref.value());
  127. }
  128. constexpr const ArrayRef<T>&& value() const&& {
  129. return std::move(wrapped_opt_array_ref.value());
  130. }
  131. template <typename U>
  132. constexpr std::
  133. enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
  134. value_or(U&& default_value) const& {
  135. return wrapped_opt_array_ref.value_or(default_value);
  136. }
  137. template <typename U>
  138. constexpr std::
  139. enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
  140. value_or(U&& default_value) && {
  141. return wrapped_opt_array_ref.value_or(default_value);
  142. }
  143. // Modifiers
  144. constexpr void swap(OptionalArrayRef& other) noexcept {
  145. std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
  146. }
  147. constexpr void reset() noexcept {
  148. wrapped_opt_array_ref.reset();
  149. }
  150. template <typename... Args>
  151. constexpr std::enable_if_t<
  152. std::is_constructible<ArrayRef<T>, Args&&...>::value,
  153. ArrayRef<T>&>
  154. emplace(Args&&... args) noexcept(
  155. std::is_nothrow_constructible<ArrayRef<T>, Args&&...>::value) {
  156. return wrapped_opt_array_ref.emplace(args...);
  157. }
  158. template <typename U, typename... Args>
  159. constexpr ArrayRef<T>& emplace(
  160. std::initializer_list<U> il,
  161. Args&&... args) noexcept {
  162. return wrapped_opt_array_ref.emplace(il, args...);
  163. }
  164. private:
  165. optional<ArrayRef<T>> wrapped_opt_array_ref;
  166. };
  167. using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
  168. inline bool operator==(
  169. const OptionalIntArrayRef& a1,
  170. const IntArrayRef& other) {
  171. if (!a1.has_value()) {
  172. return false;
  173. }
  174. return a1.value() == other;
  175. }
  176. inline bool operator==(
  177. const c10::IntArrayRef& a1,
  178. const c10::OptionalIntArrayRef& a2) {
  179. return a2 == a1;
  180. }
  181. } // namespace c10