// This file defines OptionalArrayRef, a class that has almost the same // exact functionality as c10::optional>, except that its // converting constructor fixes a dangling pointer issue. // // The implicit converting constructor of both c10::optional> and // std::optional> can cause the underlying ArrayRef to store // a dangling pointer. OptionalArrayRef prevents this by wrapping // a c10::optional> and fixing the constructor implementation. // // See https://github.com/pytorch/pytorch/issues/63645 for more on this. #pragma once #include #include namespace c10 { template class OptionalArrayRef final { public: // Constructors constexpr OptionalArrayRef() noexcept = default; constexpr OptionalArrayRef(nullopt_t) noexcept {} OptionalArrayRef(const OptionalArrayRef& other) = default; OptionalArrayRef(OptionalArrayRef&& other) = default; constexpr OptionalArrayRef(const optional>& other) noexcept : wrapped_opt_array_ref(other) {} constexpr OptionalArrayRef(optional>&& other) noexcept : wrapped_opt_array_ref(other) {} constexpr OptionalArrayRef(const T& value) noexcept : wrapped_opt_array_ref(value) {} template < typename U = ArrayRef, std::enable_if_t< !std::is_same, OptionalArrayRef>::value && !std::is_same, in_place_t>::value && std::is_constructible, U&&>::value && std::is_convertible>::value && !std::is_convertible::value, bool> = false> constexpr OptionalArrayRef(U&& value) noexcept( std::is_nothrow_constructible, U&&>::value) : wrapped_opt_array_ref(value) {} template < typename U = ArrayRef, std::enable_if_t< !std::is_same, OptionalArrayRef>::value && !std::is_same, in_place_t>::value && std::is_constructible, U&&>::value && !std::is_convertible>::value, bool> = false> constexpr explicit OptionalArrayRef(U&& value) noexcept( std::is_nothrow_constructible, U&&>::value) : wrapped_opt_array_ref(value) {} template constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept : wrapped_opt_array_ref(ip, args...) {} template constexpr explicit OptionalArrayRef( in_place_t ip, std::initializer_list il, Args&&... args) : wrapped_opt_array_ref(ip, il, args...) {} constexpr OptionalArrayRef(const std::initializer_list& Vec) : wrapped_opt_array_ref(ArrayRef(Vec)) {} // Destructor ~OptionalArrayRef() = default; // Assignment constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { wrapped_opt_array_ref = c10::nullopt; return *this; } OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; OptionalArrayRef& operator=(OptionalArrayRef&& other) = default; constexpr OptionalArrayRef& operator=( const optional>& other) noexcept { wrapped_opt_array_ref = other; return *this; } constexpr OptionalArrayRef& operator=( optional>&& other) noexcept { wrapped_opt_array_ref = other; return *this; } template > constexpr std::enable_if_t< !std::is_same, OptionalArrayRef>::value && std::is_constructible, U&&>::value && std::is_assignable&, U&&>::value, OptionalArrayRef&> operator=(U&& value) noexcept( std::is_nothrow_constructible, U&&>::value&& std::is_nothrow_assignable&, U&&>::value) { wrapped_opt_array_ref = value; return *this; } // Observers constexpr ArrayRef* operator->() noexcept { return &wrapped_opt_array_ref.value(); } constexpr const ArrayRef* operator->() const noexcept { return &wrapped_opt_array_ref.value(); } constexpr ArrayRef& operator*() & noexcept { return wrapped_opt_array_ref.value(); } constexpr const ArrayRef& operator*() const& noexcept { return wrapped_opt_array_ref.value(); } constexpr ArrayRef&& operator*() && noexcept { return std::move(wrapped_opt_array_ref.value()); } constexpr const ArrayRef&& operator*() const&& noexcept { return std::move(wrapped_opt_array_ref.value()); } constexpr explicit operator bool() const noexcept { return wrapped_opt_array_ref.has_value(); } constexpr bool has_value() const noexcept { return wrapped_opt_array_ref.has_value(); } constexpr ArrayRef& value() & { return wrapped_opt_array_ref.value(); } constexpr const ArrayRef& value() const& { return wrapped_opt_array_ref.value(); } constexpr ArrayRef&& value() && { return std::move(wrapped_opt_array_ref.value()); } constexpr const ArrayRef&& value() const&& { return std::move(wrapped_opt_array_ref.value()); } template constexpr std:: enable_if_t>::value, ArrayRef> value_or(U&& default_value) const& { return wrapped_opt_array_ref.value_or(default_value); } template constexpr std:: enable_if_t>::value, ArrayRef> value_or(U&& default_value) && { return wrapped_opt_array_ref.value_or(default_value); } // Modifiers constexpr void swap(OptionalArrayRef& other) noexcept { std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); } constexpr void reset() noexcept { wrapped_opt_array_ref.reset(); } template constexpr std::enable_if_t< std::is_constructible, Args&&...>::value, ArrayRef&> emplace(Args&&... args) noexcept( std::is_nothrow_constructible, Args&&...>::value) { return wrapped_opt_array_ref.emplace(args...); } template constexpr ArrayRef& emplace( std::initializer_list il, Args&&... args) noexcept { return wrapped_opt_array_ref.emplace(il, args...); } private: optional> wrapped_opt_array_ref; }; using OptionalIntArrayRef = OptionalArrayRef; inline bool operator==( const OptionalIntArrayRef& a1, const IntArrayRef& other) { if (!a1.has_value()) { return false; } return a1.value() == other; } inline bool operator==( const c10::IntArrayRef& a1, const c10::OptionalIntArrayRef& a2) { return a2 == a1; } } // namespace c10