// intrusive_ptr.h

#pragma once

#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/MaybeOwned.h>

#include <atomic>
#include <climits>
#include <memory>
#include <stdexcept>

namespace pybind11 {
template <typename, typename...>
class class_;
}

namespace c10 {
class intrusive_ptr_target;
namespace raw {
namespace weak_intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}
namespace intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}

// constructor tag used by intrusive_ptr constructors
struct DontIncreaseRefcount {};
} // namespace raw

/**
 * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
 * performance because it does the refcounting intrusively
 * (i.e. in a member of the object itself).
 * Your class T needs to inherit from intrusive_ptr_target to allow it to be
 * used in an intrusive_ptr<T>. Your class's constructor should not allow
 * `this` to escape to other threads or create an intrusive_ptr from `this`.
 */

// Note [Stack allocated intrusive_ptr_target safety]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A well known problem with std::enable_shared_from_this is that it
// allows you to create a std::shared_ptr from a stack allocated object,
// which is totally bogus because the object will die once you return
// from the stack. In intrusive_ptr, we can detect that this has occurred,
// because we set the refcount/weakcount of objects which inherit from
// intrusive_ptr_target to zero, *unless* we can prove that the object
// was dynamically allocated (e.g., via make_intrusive).
//
// Thus, whenever you transmute a T* into an intrusive_ptr<T>, we check
// and make sure that the refcount isn't zero (or, a more subtle
// test for weak_intrusive_ptr<T>, for which the refcount may validly
// be zero, but the weak refcount better not be zero), because that
// tells us if the object was allocated by us. If it wasn't, no
// intrusive_ptr for you!
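// Illustrative usage sketch (not part of this header; `Widget` is a
// hypothetical type): a class opts into intrusive refcounting by deriving
// from intrusive_ptr_target and being created through make_intrusive.
//
//   struct Widget : c10::intrusive_ptr_target {
//     explicit Widget(int size) : size_(size) {}
//     int size_;
//   };
//
//   void example() {
//     c10::intrusive_ptr<Widget> a = c10::make_intrusive<Widget>(42);
//     c10::intrusive_ptr<Widget> b = a; // refcount bumped to 2
//     TORCH_INTERNAL_ASSERT(a.use_count() == 2);
//   } // both pointers go out of scope, refcount hits 0, Widget is deleted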
class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  // - refcount == number of strong references to the object
  //   weakcount == number of weak references to the object,
  //     plus one more if refcount > 0
  //   An invariant: refcount > 0  =>  weakcount > 0
  //
  // - c10::StorageImpl stays live as long as there are any strong
  //   or weak pointers to it (weakcount > 0, since strong
  //   references count as a +1 to weakcount)
  //
  // - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  // - Once refcount == 0, it can never again be > 0 (the transition
  //   from > 0 to == 0 is monotonic)
  //
  // - When you access c10::StorageImpl via a weak pointer, you must
  //   atomically increment the use count, if it is greater than 0.
  //   If it is not, you must report that the storage is dead.
  //
  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;

  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
    // Disable -Wterminate and -Wexceptions so we're allowed to use assertions
    // (i.e. throw exceptions) in a destructor.
    // We also have to disable -Wunknown-warning-option and -Wpragmas, because
    // some other compilers don't know about -Wterminate or -Wexceptions and
    // will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below INT_MAX regardless.
        refcount_.load() == 0 || refcount_.load() >= INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == INT_MAX - 1 || weakcount_.load() == INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
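// Illustrative sketch (not part of this header): a hypothetical type that
// overrides release_resources() so that its expensive buffer is freed as
// soon as the last strong reference goes away, even if weak references keep
// the object's memory alive a little longer.
//
//   struct CachedBlob : c10::intrusive_ptr_target {
//     explicit CachedBlob(size_t n) : buffer_(new char[n]) {}
//
//    private:
//     void release_resources() override {
//       buffer_.reset(); // drop the payload; the object may outlive this call
//     }
//     std::unique_ptr<char[]> buffer_;
//   };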
namespace detail {
template <class TTarget>
struct intrusive_target_default_null_type final {
  static constexpr TTarget* singleton() noexcept {
    return nullptr;
  }
};

template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  if (FromNullType::singleton() == rhs) {
    return ToNullType::singleton();
  } else {
    return rhs;
  }
}

// Increment needs to be acquire-release to make use_count() and
// unique() reliable.
inline size_t atomic_refcount_increment(std::atomic<size_t>& refcount) {
  return refcount.fetch_add(1, std::memory_order_acq_rel) + 1;
}

// weak_use_count() is only used for testing, so we don't need it to
// be reliable. Relaxed should be fine.
inline size_t atomic_weakcount_increment(std::atomic<size_t>& weakcount) {
  return weakcount.fetch_add(1, std::memory_order_relaxed) + 1;
}

// Both decrements need to be acquire-release for correctness. See
// e.g. std::shared_ptr implementation.
inline size_t atomic_refcount_decrement(std::atomic<size_t>& refcount) {
  return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}

inline size_t atomic_weakcount_decrement(std::atomic<size_t>& weakcount) {
  return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
  // the following static assert would be nice to have but it requires
  // the target class T to be fully defined when intrusive_ptr<T> is
  // instantiated; this is a problem for classes that contain pointers to
  // themselves
  // static_assert(
  //     std::is_base_of<intrusive_ptr_target, TTarget>::value,
  //     "intrusive_ptr can only be used for classes that inherit from
  //     intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //   error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return an element_type* pointer");

  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom
  // smart holder in pybind11 could access the private constructor of
  // intrusive_ptr(T*) which took the ownership of the object. This is required
  // by the custom holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object.
  // For details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for
        // const objects.
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
  // correspondingly (pybind11 requires raw pointer constructor to incref by
  // default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created result.target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which
      // is much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(const intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  void swap(intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into an intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that
   * takes over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0 => weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }
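
  // Illustrative release()/reclaim() round trip (not part of this header;
  // `Widget` is a hypothetical intrusive_ptr_target subclass). This is the
  // pattern for handing ownership across a C API boundary:
  //
  //   c10::intrusive_ptr<Widget> owner = c10::make_intrusive<Widget>();
  //   Widget* raw = owner.release();  // refcount stays at 1, `owner` is null
  //   // ... pass `raw` through a C interface ...
  //   auto back = c10::intrusive_ptr<Widget>::reclaim(raw); // resume ownership
  //   // `back` destructs the Widget normally when it goes out of scope.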
  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside an intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using `new TTarget(...)`) into an intrusive_ptr. If possible,
   * use intrusive_ptr::make instead which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      size_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // INT_MAX is impractically huge for a reference count, while
    // being in no danger of overflowing size_t. We actually only need to
    // initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use INT_MAX is to accommodate
    // the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        INT_MAX + expected_decrefs, std::memory_order_relaxed);
    result.target_->weakcount_.store(INT_MAX, std::memory_order_relaxed);
    return result;
  }
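
  // Illustrative sketch (not part of this header; the storage and `Widget`
  // are hypothetical): wrapping a placement-new'd object that must never be
  // deleted through the refcounting machinery.
  //
  //   alignas(Widget) char slab[sizeof(Widget)];
  //   Widget* w = new (slab) Widget();  // arena / stack storage
  //   auto p = c10::intrusive_ptr<Widget>::unsafe_adapt_non_heap_allocated(
  //       w, /*expected_decrefs=*/1);   // one strong owner expected
  //   // `p` behaves like a normal strong reference, but the refcount was
  //   // seeded near INT_MAX so that dropping it never triggers `delete`.
  //   // The caller remains responsible for ending the Widget's lifetime.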
  /**
   * Turn a **non-owning raw pointer** into an intrusive_ptr. It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live. If
   * you are looking for the moral equivalent of unique_ptr<T>(T*)
   * constructor, see unsafe_steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>,
    class... Args>
inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
}

template <class TTarget, class NullType>
inline void swap(
    intrusive_ptr<TTarget, NullType>& lhs,
    intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}

// To allow intrusive_ptr inside std::map or std::set, we need operator<
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.get() < rhs.get();
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.get() == rhs.get();
}

template <class TTarget1, class NullType1>
inline bool operator==(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    std::nullptr_t) noexcept {
  return lhs.get() == nullptr;
}

template <class TTarget2, class NullType2>
inline bool operator==(
    std::nullptr_t,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return nullptr == rhs.get();
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator!=(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(lhs, rhs);
}

template <class TTarget1, class NullType1>
inline bool operator!=(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    std::nullptr_t) noexcept {
  return !operator==(lhs, nullptr);
}

template <class TTarget2, class NullType2>
inline bool operator!=(
    std::nullptr_t,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(nullptr, rhs);
}

template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};
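
// Illustrative sketch (not part of this header): how the traits above let
// MaybeOwned<intrusive_ptr<T>> "borrow" without touching the refcount. The
// borrow is itself an intrusive_ptr that is reclaim()'d from the raw pointer
// and release()'d again before destruction, so no incref/decref ever happens.
// `Widget` and `doSomething()` are hypothetical.
//
//   c10::intrusive_ptr<Widget> owner = c10::make_intrusive<Widget>();
//   auto view = c10::MaybeOwned<c10::intrusive_ptr<Widget>>::borrowed(owner);
//   (*view)->doSomething();  // no refcount traffic
//   // `owner` must outlive `view`, exactly like a raw pointer borrow.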
template <
    typename TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class weak_intrusive_ptr final {
 private:
  static_assert(
      std::is_base_of<intrusive_ptr_target, TTarget>::value,
      "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //   error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return an element_type* pointer");

  TTarget* target_;

  template <class TTarget2, class NullType2>
  friend class weak_intrusive_ptr;

  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_weakcount =
          detail::atomic_weakcount_increment(target_->weakcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_weakcount != 1,
          "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
    }
  }

  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
      delete target_;
    }
    target_ = NullType::singleton();
  }

  constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {}

 public:
  using element_type = TTarget;

  explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr)
      : weak_intrusive_ptr(ptr.get()) {
    retain_();
  }

  weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      const weak_intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~weak_intrusive_ptr() noexcept {
    reset_();
  }

  weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  weak_intrusive_ptr& operator=(
      weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
    weak_intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  weak_intrusive_ptr& operator=(
      const intrusive_ptr<TTarget, NullType>& rhs) & noexcept {
    weak_intrusive_ptr tmp(rhs);
    swap(tmp);
    return *this;
  }

  template <class From, class FromNullType>
  weak_intrusive_ptr& operator=(
      const weak_intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
    weak_intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  void reset() noexcept {
    reset_();
  }

  void swap(weak_intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // NB: This should ONLY be used by the std::hash implementation
  // for weak_intrusive_ptr. Another way you could do this is
  // friend std::hash<weak_intrusive_ptr>, but this triggers two
  // bugs:
  //
  //  (1) It triggers an nvcc bug, where std::hash in a friend class
  //      declaration gets preprocessed into hash, which then cannot
  //      actually be found. The error in this case looks like:
  //
  //        error: no template named 'hash'; did you mean 'std::hash'?
  //
  //  (2) On OS X, std::hash is declared as a struct, not a class.
  //      This twings:
  //
  //        error: class 'hash' was previously declared as a struct
  //        [-Werror,-Wmismatched-tags]
  //
  // Both of these are work-aroundable, but on the whole, I decided
  // it would be simpler and easier to make work if we just expose
  // an unsafe getter for target_
  //
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }

  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(
        std::memory_order_acquire); // refcount, not weakcount!
  }

  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool expired() const noexcept {
    return use_count() == 0;
  }

  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }

  /**
   * Returns an owning (but still only weakly referenced) pointer to the
   * underlying object and makes the weak_intrusive_ptr instance invalid.
   * That means the weakcount is not decreased.
   * You *must* put the returned pointer back into a weak_intrusive_ptr using
   * weak_intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning (but must be weakly referenced) pointer to TTarget* and
   * creates a weak_intrusive_ptr that takes over ownership.
   * This means that the weakcount is not increased.
   * This is the counter-part to weak_intrusive_ptr::release() and the pointer
   * passed in *must* have been created using weak_intrusive_ptr::release().
   */
  static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    // if refcount > 0, weakcount must be >1 for weak references to exist.
    // see weak counting explanation at top of this file.
    // if refcount == 0, weakcount only must be >0.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_weak_ptr == NullType::singleton() ||
            owning_weak_ptr->weakcount_.load() > 1 ||
            (owning_weak_ptr->refcount_.load() == 0 &&
             owning_weak_ptr->weakcount_.load() > 0),
        "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
    return weak_intrusive_ptr(owning_weak_ptr);
  }

  /**
   * Takes a pointer to TTarget* (may be weak or strong) and creates a
   * new weak_intrusive_ptr representing a new weak reference, i.e.
   * the raw pointer retains ownership.
   */
  static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator<(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator==(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
};
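
// Illustrative sketch (not part of this header; `Widget` is hypothetical):
// the usual strong -> weak -> lock() dance, mirroring std::weak_ptr.
//
//   auto strong = c10::make_intrusive<Widget>();
//   c10::weak_intrusive_ptr<Widget> weak(strong);  // weakcount + 1
//   if (auto locked = weak.lock()) {               // succeeds: refcount 1 -> 2
//     // use *locked
//   }
//   strong.reset();                                // last strong ref gone
//   TORCH_INTERNAL_ASSERT(weak.expired() && !weak.lock());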
template <class TTarget, class NullType>
inline void swap(
    weak_intrusive_ptr<TTarget, NullType>& lhs,
    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}

// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ == rhs.target_;
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator!=(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(lhs, rhs);
}

// Alias for documentary purposes, to more easily distinguish
// weak raw intrusive pointers from intrusive pointers.
using weak_intrusive_ptr_target = intrusive_ptr_target;

// This namespace provides some methods for working with
// raw pointers that subclass intrusive_ptr_target. They are not provided
// as methods on intrusive_ptr_target, because ideally you would not need these
// methods at all (use smart pointers), but if you are dealing with legacy code
// that still needs to pass around raw pointers, you may find these quite
// useful.
//
// An important usage note: some functions are only valid if you have a
// strong raw pointer to the object, while others are only valid if you
// have a weak raw pointer to the object. ONLY call intrusive_ptr namespace
// functions on strong pointers, and weak_intrusive_ptr namespace functions
// on weak pointers. If you mix it up, you may get an assert failure.
namespace raw {
namespace intrusive_ptr {
// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
  if (self) {
    detail::atomic_refcount_increment(self->refcount_);
  }
}

// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void decref(intrusive_ptr_target* self) {
  // Let it die
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}

template <typename T>
inline T* make_weak(T* self) {
  // NB: 'this' is a strong pointer, but we return a weak pointer
  auto ptr = c10::intrusive_ptr<T>::reclaim(self);
  c10::weak_intrusive_ptr<T> wptr(ptr);
  ptr.release();
  return wptr.release();
}

inline size_t use_count(intrusive_ptr_target* self) {
  auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = ptr.use_count();
  ptr.release();
  return r;
}
} // namespace intrusive_ptr

namespace weak_intrusive_ptr {
inline void incref(weak_intrusive_ptr_target* self) {
  detail::atomic_weakcount_increment(self->weakcount_);
}

inline void decref(weak_intrusive_ptr_target* self) {
  // Let it die
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}

template <typename T>
inline T* lock(T* self) {
  auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self);
  auto ptr = wptr.lock();
  wptr.release();
  return ptr.release();
}

// This gives the STRONG refcount of a WEAK pointer
inline size_t use_count(weak_intrusive_ptr_target* self) {
  auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = wptr.use_count();
  wptr.release();
  return r;
}
} // namespace weak_intrusive_ptr
} // namespace raw
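
// Illustrative sketch (not part of this header): how legacy code holding a
// strong raw pointer typically uses the raw:: helpers above. `Widget` is a
// hypothetical intrusive_ptr_target subclass.
//
//   Widget* raw = c10::make_intrusive<Widget>().release(); // strong raw ptr
//   c10::raw::intrusive_ptr::incref(raw);  // now two strong references
//   c10::raw::intrusive_ptr::decref(raw);  // back to one
//   c10::raw::intrusive_ptr::decref(raw);  // last one; Widget is freed and
//                                          // `raw` is dangling from here on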
} // namespace c10

namespace std {
// To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
// std::unordered_set, we need std::hash
template <class TTarget, class NullType>
struct hash<c10::intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x.get());
  }
};
template <class TTarget, class NullType>
struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x._unsafe_get_target());
  }
};
} // namespace std
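
// Illustrative sketch (not part of this header): with the std::hash
// specializations above, intrusive_ptr keys work in unordered containers,
// hashing by pointer identity. `Widget` is a hypothetical target type.
//
//   #include <unordered_set>
//   std::unordered_set<c10::intrusive_ptr<Widget>> seen;
//   auto w = c10::make_intrusive<Widget>();
//   seen.insert(w);
//   TORCH_INTERNAL_ASSERT(seen.count(w) == 1); // same object hashes equal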