sparse_bitset.h 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904
  1. //===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // This file defines the SparseBitVector class. See the doxygen comment for
  10. // SparseBitVector for more details on the algorithm used.
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #pragma once
  #include <c10/macros/Macros.h>
  #include <c10/util/llvmMathExtras.h>
  #include <algorithm>
  #include <cassert>
  #include <climits>
  #include <cstring>
  #include <iterator>
  #include <list>
  #include <ostream>
  #include <stdexcept>
  #include <utility>
  21. C10_CLANG_DIAGNOSTIC_PUSH()
  22. #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
  23. C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
  24. #endif
  25. namespace c10 {
  26. /// SparseBitVector is an implementation of a bitvector that is sparse by only
  27. /// storing the elements that have non-zero bits set. In order to make this
  28. /// fast for the most common cases, SparseBitVector is implemented as a linked
  29. /// list of SparseBitVectorElements. We maintain a pointer to the last
  30. /// SparseBitVectorElement accessed (in the form of a list iterator), in order
  31. /// to make multiple in-order test/set constant time after the first one is
  32. /// executed. Note that using vectors to store SparseBitVectorElement's does
  33. /// not work out very well because it causes insertion in the middle to take
  34. /// enormous amounts of time with a large amount of bits. Other structures that
  35. /// have better worst cases for insertion in the middle (various balanced trees,
  36. /// etc) do not perform as well in practice as a linked list with this iterator
  37. /// kept up to date. They are also significantly more memory intensive.
  38. template <unsigned ElementSize = 128>
  39. struct SparseBitVectorElement {
  40. public:
  41. using BitWord = unsigned long;
  42. using size_type = unsigned;
  43. enum {
  44. BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
  45. BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
  46. BITS_PER_ELEMENT = ElementSize
  47. };
  48. private:
  49. // Index of Element in terms of where first bit starts.
  50. unsigned ElementIndex;
  51. BitWord Bits[BITWORDS_PER_ELEMENT];
  52. SparseBitVectorElement() {
  53. ElementIndex = ~0U;
  54. memset(&Bits[0], 0, sizeof(BitWord) * BITWORDS_PER_ELEMENT);
  55. }
  56. public:
  57. explicit SparseBitVectorElement(unsigned Idx) {
  58. ElementIndex = Idx;
  59. memset(&Bits[0], 0, sizeof(BitWord) * BITWORDS_PER_ELEMENT);
  60. }
  61. // Comparison.
  62. bool operator==(const SparseBitVectorElement& RHS) const {
  63. if (ElementIndex != RHS.ElementIndex)
  64. return false;
  65. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  66. if (Bits[i] != RHS.Bits[i])
  67. return false;
  68. return true;
  69. }
  70. bool operator!=(const SparseBitVectorElement& RHS) const {
  71. return !(*this == RHS);
  72. }
  73. // Return the bits that make up word Idx in our element.
  74. BitWord word(unsigned Idx) const {
  75. assert(Idx < BITWORDS_PER_ELEMENT);
  76. return Bits[Idx];
  77. }
  78. unsigned index() const {
  79. return ElementIndex;
  80. }
  81. bool empty() const {
  82. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  83. if (Bits[i])
  84. return false;
  85. return true;
  86. }
  87. void set(unsigned Idx) {
  88. Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE);
  89. }
  90. bool test_and_set(unsigned Idx) {
  91. bool old = test(Idx);
  92. if (!old) {
  93. set(Idx);
  94. return true;
  95. }
  96. return false;
  97. }
  98. void reset(unsigned Idx) {
  99. Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE));
  100. }
  101. bool test(unsigned Idx) const {
  102. return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE));
  103. }
  104. size_type count() const {
  105. unsigned NumBits = 0;
  106. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  107. NumBits += llvm::countPopulation(Bits[i]);
  108. return NumBits;
  109. }
  110. /// find_first - Returns the index of the first set bit.
  111. int find_first() const {
  112. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
  113. if (Bits[i] != 0)
  114. return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
  115. throw std::runtime_error("Illegal empty element");
  116. }
  117. /// find_last - Returns the index of the last set bit.
  118. int find_last() const {
  119. for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
  120. unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
  121. if (Bits[Idx] != 0)
  122. return Idx * BITWORD_SIZE + BITWORD_SIZE -
  123. llvm::countLeadingZeros(Bits[Idx]);
  124. }
  125. throw std::runtime_error("Illegal empty element");
  126. }
  127. /// find_next - Returns the index of the next set bit starting from the
  128. /// "Curr" bit. Returns -1 if the next set bit is not found.
  129. int find_next(unsigned Curr) const {
  130. if (Curr >= BITS_PER_ELEMENT)
  131. return -1;
  132. unsigned WordPos = Curr / BITWORD_SIZE;
  133. unsigned BitPos = Curr % BITWORD_SIZE;
  134. BitWord Copy = Bits[WordPos];
  135. assert(
  136. WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element");
  137. // Mask off previous bits.
  138. Copy &= ~0UL << BitPos;
  139. if (Copy != 0)
  140. return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
  141. // Check subsequent words.
  142. for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
  143. if (Bits[i] != 0)
  144. return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
  145. return -1;
  146. }
  147. // Union this element with RHS and return true if this one changed.
  148. bool unionWith(const SparseBitVectorElement& RHS) {
  149. bool changed = false;
  150. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  151. BitWord old = changed ? 0 : Bits[i];
  152. Bits[i] |= RHS.Bits[i];
  153. if (!changed && old != Bits[i])
  154. changed = true;
  155. }
  156. return changed;
  157. }
  158. // Return true if we have any bits in common with RHS
  159. bool intersects(const SparseBitVectorElement& RHS) const {
  160. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  161. if (RHS.Bits[i] & Bits[i])
  162. return true;
  163. }
  164. return false;
  165. }
  166. // Intersect this Element with RHS and return true if this one changed.
  167. // BecameZero is set to true if this element became all-zero bits.
  168. bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
  169. bool changed = false;
  170. bool allzero = true;
  171. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  172. BitWord old = changed ? 0 : Bits[i];
  173. Bits[i] &= RHS.Bits[i];
  174. if (Bits[i] != 0)
  175. allzero = false;
  176. if (!changed && old != Bits[i])
  177. changed = true;
  178. }
  179. BecameZero = allzero;
  180. return changed;
  181. }
  182. // Intersect this Element with the complement of RHS and return true if this
  183. // one changed. BecameZero is set to true if this element became all-zero
  184. // bits.
  185. bool intersectWithComplement(
  186. const SparseBitVectorElement& RHS,
  187. bool& BecameZero) {
  188. bool changed = false;
  189. bool allzero = true;
  190. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  191. BitWord old = changed ? 0 : Bits[i];
  192. Bits[i] &= ~RHS.Bits[i];
  193. if (Bits[i] != 0)
  194. allzero = false;
  195. if (!changed && old != Bits[i])
  196. changed = true;
  197. }
  198. BecameZero = allzero;
  199. return changed;
  200. }
  201. // Three argument version of intersectWithComplement that intersects
  202. // RHS1 & ~RHS2 into this element
  203. void intersectWithComplement(
  204. const SparseBitVectorElement& RHS1,
  205. const SparseBitVectorElement& RHS2,
  206. bool& BecameZero) {
  207. bool allzero = true;
  208. for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
  209. Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
  210. if (Bits[i] != 0)
  211. allzero = false;
  212. }
  213. BecameZero = allzero;
  214. }
  215. };
  216. template <unsigned ElementSize = 128>
  217. class SparseBitVector {
  218. using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
  219. using ElementListIter = typename ElementList::iterator;
  220. using ElementListConstIter = typename ElementList::const_iterator;
  221. enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };
  222. ElementList Elements;
  223. // Pointer to our current Element. This has no visible effect on the external
  224. // state of a SparseBitVector, it's just used to improve performance in the
  225. // common case of testing/modifying bits with similar indices.
  226. mutable ElementListIter CurrElementIter;
  227. // This is like std::lower_bound, except we do linear searching from the
  228. // current position.
  229. ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
  230. // We cache a non-const iterator so we're forced to resort to const_cast to
  231. // get the begin/end in the case where 'this' is const. To avoid duplication
  232. // of code with the only difference being whether the const cast is present
  233. // 'this' is always const in this particular function and we sort out the
  234. // difference in FindLowerBound and FindLowerBoundConst.
  235. ElementListIter Begin =
  236. const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin();
  237. ElementListIter End =
  238. const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end();
  239. if (Elements.empty()) {
  240. CurrElementIter = Begin;
  241. return CurrElementIter;
  242. }
  243. // Make sure our current iterator is valid.
  244. if (CurrElementIter == End)
  245. --CurrElementIter;
  246. // Search from our current iterator, either backwards or forwards,
  247. // depending on what element we are looking for.
  248. ElementListIter ElementIter = CurrElementIter;
  249. if (CurrElementIter->index() == ElementIndex) {
  250. return ElementIter;
  251. } else if (CurrElementIter->index() > ElementIndex) {
  252. while (ElementIter != Begin && ElementIter->index() > ElementIndex)
  253. --ElementIter;
  254. } else {
  255. while (ElementIter != End && ElementIter->index() < ElementIndex)
  256. ++ElementIter;
  257. }
  258. CurrElementIter = ElementIter;
  259. return ElementIter;
  260. }
  261. ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
  262. return FindLowerBoundImpl(ElementIndex);
  263. }
  264. ElementListIter FindLowerBound(unsigned ElementIndex) {
  265. return FindLowerBoundImpl(ElementIndex);
  266. }
  267. // Iterator to walk set bits in the bitmap. This iterator is a lot uglier
  268. // than it would be, in order to be efficient.
  269. class SparseBitVectorIterator {
  270. private:
  271. bool AtEnd;
  272. const SparseBitVector<ElementSize>* BitVector = nullptr;
  273. // Current element inside of bitmap.
  274. ElementListConstIter Iter;
  275. // Current bit number inside of our bitmap.
  276. unsigned BitNumber;
  277. // Current word number inside of our element.
  278. unsigned WordNumber;
  279. // Current bits from the element.
  280. typename SparseBitVectorElement<ElementSize>::BitWord Bits;
  281. // Move our iterator to the first non-zero bit in the bitmap.
  282. void AdvanceToFirstNonZero() {
  283. if (AtEnd)
  284. return;
  285. if (BitVector->Elements.empty()) {
  286. AtEnd = true;
  287. return;
  288. }
  289. Iter = BitVector->Elements.begin();
  290. BitNumber = Iter->index() * ElementSize;
  291. unsigned BitPos = Iter->find_first();
  292. BitNumber += BitPos;
  293. WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
  294. Bits = Iter->word(WordNumber);
  295. Bits >>= BitPos % BITWORD_SIZE;
  296. }
  297. // Move our iterator to the next non-zero bit.
  298. void AdvanceToNextNonZero() {
  299. if (AtEnd)
  300. return;
  301. while (Bits && !(Bits & 1)) {
  302. Bits >>= 1;
  303. BitNumber += 1;
  304. }
  305. // See if we ran out of Bits in this word.
  306. if (!Bits) {
  307. int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize);
  308. // If we ran out of set bits in this element, move to next element.
  309. if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
  310. ++Iter;
  311. WordNumber = 0;
  312. // We may run out of elements in the bitmap.
  313. if (Iter == BitVector->Elements.end()) {
  314. AtEnd = true;
  315. return;
  316. }
  317. // Set up for next non-zero word in bitmap.
  318. BitNumber = Iter->index() * ElementSize;
  319. NextSetBitNumber = Iter->find_first();
  320. BitNumber += NextSetBitNumber;
  321. WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
  322. Bits = Iter->word(WordNumber);
  323. Bits >>= NextSetBitNumber % BITWORD_SIZE;
  324. } else {
  325. WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
  326. Bits = Iter->word(WordNumber);
  327. Bits >>= NextSetBitNumber % BITWORD_SIZE;
  328. BitNumber = Iter->index() * ElementSize;
  329. BitNumber += NextSetBitNumber;
  330. }
  331. }
  332. }
  333. public:
  334. SparseBitVectorIterator() = default;
  335. SparseBitVectorIterator(
  336. const SparseBitVector<ElementSize>* RHS,
  337. bool end = false)
  338. : BitVector(RHS) {
  339. Iter = BitVector->Elements.begin();
  340. BitNumber = 0;
  341. Bits = 0;
  342. WordNumber = ~0;
  343. AtEnd = end;
  344. AdvanceToFirstNonZero();
  345. }
  346. // Preincrement.
  347. inline SparseBitVectorIterator& operator++() {
  348. ++BitNumber;
  349. Bits >>= 1;
  350. AdvanceToNextNonZero();
  351. return *this;
  352. }
  353. // Postincrement.
  354. inline SparseBitVectorIterator operator++(int) {
  355. SparseBitVectorIterator tmp = *this;
  356. ++*this;
  357. return tmp;
  358. }
  359. // Return the current set bit number.
  360. unsigned operator*() const {
  361. return BitNumber;
  362. }
  363. bool operator==(const SparseBitVectorIterator& RHS) const {
  364. // If they are both at the end, ignore the rest of the fields.
  365. if (AtEnd && RHS.AtEnd)
  366. return true;
  367. // Otherwise they are the same if they have the same bit number and
  368. // bitmap.
  369. return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
  370. }
  371. bool operator!=(const SparseBitVectorIterator& RHS) const {
  372. return !(*this == RHS);
  373. }
  374. };
  375. public:
  376. using iterator = SparseBitVectorIterator;
  377. SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {}
  378. SparseBitVector(const SparseBitVector& RHS)
  379. : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {}
  380. SparseBitVector(SparseBitVector&& RHS)
  381. : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {}
  382. // Clear.
  383. void clear() {
  384. Elements.clear();
  385. }
  386. // Assignment
  387. SparseBitVector& operator=(const SparseBitVector& RHS) {
  388. if (this == &RHS)
  389. return *this;
  390. Elements = RHS.Elements;
  391. CurrElementIter = Elements.begin();
  392. return *this;
  393. }
  394. SparseBitVector& operator=(SparseBitVector&& RHS) {
  395. Elements = std::move(RHS.Elements);
  396. CurrElementIter = Elements.begin();
  397. return *this;
  398. }
  399. // Test, Reset, and Set a bit in the bitmap.
  400. bool test(unsigned Idx) const {
  401. if (Elements.empty())
  402. return false;
  403. unsigned ElementIndex = Idx / ElementSize;
  404. ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex);
  405. // If we can't find an element that is supposed to contain this bit, there
  406. // is nothing more to do.
  407. if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
  408. return false;
  409. return ElementIter->test(Idx % ElementSize);
  410. }
  411. void reset(unsigned Idx) {
  412. if (Elements.empty())
  413. return;
  414. unsigned ElementIndex = Idx / ElementSize;
  415. ElementListIter ElementIter = FindLowerBound(ElementIndex);
  416. // If we can't find an element that is supposed to contain this bit, there
  417. // is nothing more to do.
  418. if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
  419. return;
  420. ElementIter->reset(Idx % ElementSize);
  421. // When the element is zeroed out, delete it.
  422. if (ElementIter->empty()) {
  423. ++CurrElementIter;
  424. Elements.erase(ElementIter);
  425. }
  426. }
  427. void set(unsigned Idx) {
  428. unsigned ElementIndex = Idx / ElementSize;
  429. ElementListIter ElementIter;
  430. if (Elements.empty()) {
  431. ElementIter = Elements.emplace(Elements.end(), ElementIndex);
  432. } else {
  433. ElementIter = FindLowerBound(ElementIndex);
  434. if (ElementIter == Elements.end() ||
  435. ElementIter->index() != ElementIndex) {
  436. // We may have hit the beginning of our SparseBitVector, in which case,
  437. // we may need to insert right after this element, which requires moving
  438. // the current iterator forward one, because insert does insert before.
  439. if (ElementIter != Elements.end() &&
  440. ElementIter->index() < ElementIndex)
  441. ++ElementIter;
  442. ElementIter = Elements.emplace(ElementIter, ElementIndex);
  443. }
  444. }
  445. CurrElementIter = ElementIter;
  446. ElementIter->set(Idx % ElementSize);
  447. }
  448. bool test_and_set(unsigned Idx) {
  449. bool old = test(Idx);
  450. if (!old) {
  451. set(Idx);
  452. return true;
  453. }
  454. return false;
  455. }
  456. bool operator!=(const SparseBitVector& RHS) const {
  457. return !(*this == RHS);
  458. }
  459. bool operator==(const SparseBitVector& RHS) const {
  460. ElementListConstIter Iter1 = Elements.begin();
  461. ElementListConstIter Iter2 = RHS.Elements.begin();
  462. for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end();
  463. ++Iter1, ++Iter2) {
  464. if (*Iter1 != *Iter2)
  465. return false;
  466. }
  467. return Iter1 == Elements.end() && Iter2 == RHS.Elements.end();
  468. }
  469. // Union our bitmap with the RHS and return true if we changed.
  470. bool operator|=(const SparseBitVector& RHS) {
  471. if (this == &RHS)
  472. return false;
  473. if (empty()) {
  474. *this = RHS;
  475. return true;
  476. }
  477. bool changed = false;
  478. ElementListIter Iter1 = Elements.begin();
  479. ElementListConstIter Iter2 = RHS.Elements.begin();
  480. // If RHS is empty, we are done
  481. if (RHS.Elements.empty())
  482. return false;
  483. while (Iter2 != RHS.Elements.end()) {
  484. if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) {
  485. Elements.insert(Iter1, *Iter2);
  486. ++Iter2;
  487. changed = true;
  488. } else if (Iter1->index() == Iter2->index()) {
  489. changed |= Iter1->unionWith(*Iter2);
  490. ++Iter1;
  491. ++Iter2;
  492. } else {
  493. ++Iter1;
  494. }
  495. }
  496. CurrElementIter = Elements.begin();
  497. return changed;
  498. }
  499. // Intersect our bitmap with the RHS and return true if ours changed.
  500. bool operator-=(const SparseBitVector& RHS) {
  501. return intersectWithComplement(RHS);
  502. }
  503. // Intersect our bitmap with the RHS and return true if ours changed.
  504. bool operator&=(const SparseBitVector& RHS) {
  505. if (this == &RHS)
  506. return false;
  507. bool changed = false;
  508. ElementListIter Iter1 = Elements.begin();
  509. ElementListConstIter Iter2 = RHS.Elements.begin();
  510. // Check if both bitmaps are empty.
  511. if (Elements.empty() && RHS.Elements.empty())
  512. return false;
  513. // Loop through, intersecting as we go, erasing elements when necessary.
  514. while (Iter2 != RHS.Elements.end()) {
  515. if (Iter1 == Elements.end()) {
  516. CurrElementIter = Elements.begin();
  517. return changed;
  518. }
  519. if (Iter1->index() > Iter2->index()) {
  520. ++Iter2;
  521. } else if (Iter1->index() == Iter2->index()) {
  522. bool BecameZero;
  523. changed |= Iter1->intersectWith(*Iter2, BecameZero);
  524. if (BecameZero) {
  525. ElementListIter IterTmp = Iter1;
  526. ++Iter1;
  527. Elements.erase(IterTmp);
  528. } else {
  529. ++Iter1;
  530. }
  531. ++Iter2;
  532. } else {
  533. ElementListIter IterTmp = Iter1;
  534. ++Iter1;
  535. Elements.erase(IterTmp);
  536. changed = true;
  537. }
  538. }
  539. if (Iter1 != Elements.end()) {
  540. Elements.erase(Iter1, Elements.end());
  541. changed = true;
  542. }
  543. CurrElementIter = Elements.begin();
  544. return changed;
  545. }
  546. // Intersect our bitmap with the complement of the RHS and return true
  547. // if ours changed.
  548. bool intersectWithComplement(const SparseBitVector& RHS) {
  549. if (this == &RHS) {
  550. if (!empty()) {
  551. clear();
  552. return true;
  553. }
  554. return false;
  555. }
  556. bool changed = false;
  557. ElementListIter Iter1 = Elements.begin();
  558. ElementListConstIter Iter2 = RHS.Elements.begin();
  559. // If either our bitmap or RHS is empty, we are done
  560. if (Elements.empty() || RHS.Elements.empty())
  561. return false;
  562. // Loop through, intersecting as we go, erasing elements when necessary.
  563. while (Iter2 != RHS.Elements.end()) {
  564. if (Iter1 == Elements.end()) {
  565. CurrElementIter = Elements.begin();
  566. return changed;
  567. }
  568. if (Iter1->index() > Iter2->index()) {
  569. ++Iter2;
  570. } else if (Iter1->index() == Iter2->index()) {
  571. bool BecameZero;
  572. changed |= Iter1->intersectWithComplement(*Iter2, BecameZero);
  573. if (BecameZero) {
  574. ElementListIter IterTmp = Iter1;
  575. ++Iter1;
  576. Elements.erase(IterTmp);
  577. } else {
  578. ++Iter1;
  579. }
  580. ++Iter2;
  581. } else {
  582. ++Iter1;
  583. }
  584. }
  585. CurrElementIter = Elements.begin();
  586. return changed;
  587. }
  588. bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const {
  589. return intersectWithComplement(*RHS);
  590. }
  591. // Three argument version of intersectWithComplement.
  592. // Result of RHS1 & ~RHS2 is stored into this bitmap.
  593. void intersectWithComplement(
  594. const SparseBitVector<ElementSize>& RHS1,
  595. const SparseBitVector<ElementSize>& RHS2) {
  596. if (this == &RHS1) {
  597. intersectWithComplement(RHS2);
  598. return;
  599. } else if (this == &RHS2) {
  600. SparseBitVector RHS2Copy(RHS2);
  601. intersectWithComplement(RHS1, RHS2Copy);
  602. return;
  603. }
  604. Elements.clear();
  605. CurrElementIter = Elements.begin();
  606. ElementListConstIter Iter1 = RHS1.Elements.begin();
  607. ElementListConstIter Iter2 = RHS2.Elements.begin();
  608. // If RHS1 is empty, we are done
  609. // If RHS2 is empty, we still have to copy RHS1
  610. if (RHS1.Elements.empty())
  611. return;
  612. // Loop through, intersecting as we go, erasing elements when necessary.
  613. while (Iter2 != RHS2.Elements.end()) {
  614. if (Iter1 == RHS1.Elements.end())
  615. return;
  616. if (Iter1->index() > Iter2->index()) {
  617. ++Iter2;
  618. } else if (Iter1->index() == Iter2->index()) {
  619. bool BecameZero = false;
  620. Elements.emplace_back(Iter1->index());
  621. Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero);
  622. if (BecameZero)
  623. Elements.pop_back();
  624. ++Iter1;
  625. ++Iter2;
  626. } else {
  627. Elements.push_back(*Iter1++);
  628. }
  629. }
  630. // copy the remaining elements
  631. std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements));
  632. }
  633. void intersectWithComplement(
  634. const SparseBitVector<ElementSize>* RHS1,
  635. const SparseBitVector<ElementSize>* RHS2) {
  636. intersectWithComplement(*RHS1, *RHS2);
  637. }
  638. bool intersects(const SparseBitVector<ElementSize>* RHS) const {
  639. return intersects(*RHS);
  640. }
  641. // Return true if we share any bits in common with RHS
  642. bool intersects(const SparseBitVector<ElementSize>& RHS) const {
  643. ElementListConstIter Iter1 = Elements.begin();
  644. ElementListConstIter Iter2 = RHS.Elements.begin();
  645. // Check if both bitmaps are empty.
  646. if (Elements.empty() && RHS.Elements.empty())
  647. return false;
  648. // Loop through, intersecting stopping when we hit bits in common.
  649. while (Iter2 != RHS.Elements.end()) {
  650. if (Iter1 == Elements.end())
  651. return false;
  652. if (Iter1->index() > Iter2->index()) {
  653. ++Iter2;
  654. } else if (Iter1->index() == Iter2->index()) {
  655. if (Iter1->intersects(*Iter2))
  656. return true;
  657. ++Iter1;
  658. ++Iter2;
  659. } else {
  660. ++Iter1;
  661. }
  662. }
  663. return false;
  664. }
  665. // Return true iff all bits set in this SparseBitVector are
  666. // also set in RHS.
  667. bool contains(const SparseBitVector<ElementSize>& RHS) const {
  668. SparseBitVector<ElementSize> Result(*this);
  669. Result &= RHS;
  670. return (Result == RHS);
  671. }
  672. // Return the first set bit in the bitmap. Return -1 if no bits are set.
  673. int find_first() const {
  674. if (Elements.empty())
  675. return -1;
  676. const SparseBitVectorElement<ElementSize>& First = *(Elements.begin());
  677. return (First.index() * ElementSize) + First.find_first();
  678. }
  679. // Return the last set bit in the bitmap. Return -1 if no bits are set.
  680. int find_last() const {
  681. if (Elements.empty())
  682. return -1;
  683. const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin());
  684. return (Last.index() * ElementSize) + Last.find_last();
  685. }
  686. // Return true if the SparseBitVector is empty
  687. bool empty() const {
  688. return Elements.empty();
  689. }
  690. unsigned count() const {
  691. unsigned BitCount = 0;
  692. for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end();
  693. ++Iter)
  694. BitCount += Iter->count();
  695. return BitCount;
  696. }
  697. iterator begin() const {
  698. return iterator(this);
  699. }
  700. iterator end() const {
  701. return iterator(this, true);
  702. }
  703. };
  704. // Convenience functions to allow Or and And without dereferencing in the user
  705. // code.
  706. template <unsigned ElementSize>
  707. inline bool operator|=(
  708. SparseBitVector<ElementSize>& LHS,
  709. const SparseBitVector<ElementSize>* RHS) {
  710. return LHS |= *RHS;
  711. }
  712. template <unsigned ElementSize>
  713. inline bool operator|=(
  714. SparseBitVector<ElementSize>* LHS,
  715. const SparseBitVector<ElementSize>& RHS) {
  716. return LHS->operator|=(RHS);
  717. }
  718. template <unsigned ElementSize>
  719. inline bool operator&=(
  720. SparseBitVector<ElementSize>* LHS,
  721. const SparseBitVector<ElementSize>& RHS) {
  722. return LHS->operator&=(RHS);
  723. }
  724. template <unsigned ElementSize>
  725. inline bool operator&=(
  726. SparseBitVector<ElementSize>& LHS,
  727. const SparseBitVector<ElementSize>* RHS) {
  728. return LHS &= *RHS;
  729. }
  730. // Convenience functions for infix union, intersection, difference operators.
  731. template <unsigned ElementSize>
  732. inline SparseBitVector<ElementSize> operator|(
  733. const SparseBitVector<ElementSize>& LHS,
  734. const SparseBitVector<ElementSize>& RHS) {
  735. SparseBitVector<ElementSize> Result(LHS);
  736. Result |= RHS;
  737. return Result;
  738. }
  739. template <unsigned ElementSize>
  740. inline SparseBitVector<ElementSize> operator&(
  741. const SparseBitVector<ElementSize>& LHS,
  742. const SparseBitVector<ElementSize>& RHS) {
  743. SparseBitVector<ElementSize> Result(LHS);
  744. Result &= RHS;
  745. return Result;
  746. }
  747. template <unsigned ElementSize>
  748. inline SparseBitVector<ElementSize> operator-(
  749. const SparseBitVector<ElementSize>& LHS,
  750. const SparseBitVector<ElementSize>& RHS) {
  751. SparseBitVector<ElementSize> Result;
  752. Result.intersectWithComplement(LHS, RHS);
  753. return Result;
  754. }
  755. template <unsigned ElementSize>
  756. std::ostream& operator<<(
  757. std::ostream& stream,
  758. const SparseBitVector<ElementSize>& vec) {
  759. bool first = true;
  760. stream << "{";
  761. for (auto el : vec) {
  762. if (first) {
  763. first = false;
  764. } else {
  765. stream << ", ";
  766. }
  767. stream << el;
  768. }
  769. stream << "}";
  770. return stream;
  771. }
  772. } // end namespace c10
  773. C10_CLANG_DIAGNOSTIC_POP()