Index: include/llvm/ADT/DenseMap.h =================================================================== --- include/llvm/ADT/DenseMap.h +++ include/llvm/ADT/DenseMap.h @@ -19,6 +19,7 @@ #include "llvm/Support/AlignOf.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ReverseIteration.h" #include "llvm/Support/type_traits.h" #include #include @@ -67,18 +68,27 @@ DenseMapIterator; inline iterator begin() { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) + return makeIterator(getBucketsEnd() - 1, getBuckets(), *this); +#endif // When the map is empty, avoid the overhead of AdvancePastEmptyBuckets(). - return empty() ? end() : iterator(getBuckets(), getBucketsEnd(), *this); + return empty() ? end() + : makeIterator(getBuckets(), getBucketsEnd(), *this); } inline iterator end() { - return iterator(getBucketsEnd(), getBucketsEnd(), *this, true); + return makeIterator(getBucketsEnd(), getBucketsEnd(), *this, true); } inline const_iterator begin() const { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) + return makeConstIterator(getBucketsEnd() - 1, getBuckets(), *this); +#endif return empty() ? end() - : const_iterator(getBuckets(), getBucketsEnd(), *this); + : makeConstIterator(getBuckets(), getBucketsEnd(), *this); } inline const_iterator end() const { - return const_iterator(getBucketsEnd(), getBucketsEnd(), *this, true); + return makeConstIterator(getBucketsEnd(), getBucketsEnd(), *this, true); } LLVM_NODISCARD bool empty() const { @@ -131,13 +141,13 @@ iterator find(const_arg_type_t Val) { BucketT *TheBucket; if (LookupBucketFor(Val, TheBucket)) - return iterator(TheBucket, getBucketsEnd(), *this, true); + return makeIterator(TheBucket, getBucketsEnd(), *this, true); return end(); } const_iterator find(const_arg_type_t Val) const { const BucketT *TheBucket; if (LookupBucketFor(Val, TheBucket)) - return const_iterator(TheBucket, getBucketsEnd(), *this, true); + return makeConstIterator(TheBucket, getBucketsEnd(), *this, true); return end(); } @@ -150,14 +160,14 @@ iterator find_as(const LookupKeyT &Val) { BucketT *TheBucket; if (LookupBucketFor(Val, TheBucket)) - return iterator(TheBucket, getBucketsEnd(), *this, true); + return makeIterator(TheBucket, getBucketsEnd(), *this, true); return end(); } template const_iterator find_as(const LookupKeyT &Val) const { const BucketT *TheBucket; if (LookupBucketFor(Val, TheBucket)) - return const_iterator(TheBucket, getBucketsEnd(), *this, true); + return makeConstIterator(TheBucket, getBucketsEnd(), *this, true); return end(); } @@ -191,14 +201,16 @@ std::pair try_emplace(KeyT &&Key, Ts &&... Args) { BucketT *TheBucket; if (LookupBucketFor(Key, TheBucket)) - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - false); // Already in map. + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + false); // Already in map. // Otherwise, insert the new element. TheBucket = InsertIntoBucket(TheBucket, std::move(Key), std::forward(Args)...); - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - true); + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + true); } // Inserts key,value pair into the map if the key isn't already in the map. @@ -208,13 +220,15 @@ std::pair try_emplace(const KeyT &Key, Ts &&... Args) { BucketT *TheBucket; if (LookupBucketFor(Key, TheBucket)) - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - false); // Already in map. + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + false); // Already in map. // Otherwise, insert the new element. TheBucket = InsertIntoBucket(TheBucket, Key, std::forward(Args)...); - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - true); + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + true); } /// Alternate version of insert() which allows a different, and possibly @@ -227,14 +241,16 @@ const LookupKeyT &Val) { BucketT *TheBucket; if (LookupBucketFor(Val, TheBucket)) - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - false); // Already in map. + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + false); // Already in map. // Otherwise, insert the new element. TheBucket = InsertIntoBucketWithLookup(TheBucket, std::move(KV.first), std::move(KV.second), Val); - return std::make_pair(iterator(TheBucket, getBucketsEnd(), *this, true), - true); + return std::make_pair( + makeIterator(TheBucket, getBucketsEnd(), *this, true), + true); } /// insert - Range insertion of pairs. @@ -405,6 +421,30 @@ } private: + iterator makeIterator(BucketT *P, BucketT *E, + DebugEpochBase &Epoch, + bool NoAdvance=false) { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) { + BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1; + return iterator(B, E, Epoch, NoAdvance); + } +#endif + return iterator(P, E, Epoch, NoAdvance); + } + + const_iterator makeConstIterator(const BucketT *P, const BucketT *E, + const DebugEpochBase &Epoch, + const bool NoAdvance=false) const { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) { + const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1; + return const_iterator(B, E, Epoch, NoAdvance); + } +#endif + return const_iterator(P, E, Epoch, NoAdvance); + } + unsigned getNumEntries() const { return static_cast(this)->getNumEntries(); } @@ -1089,6 +1129,12 @@ bool NoAdvance = false) : DebugEpochBase::HandleBase(&Epoch), Ptr(Pos), End(E) { assert(isHandleInSync() && "invalid construction!"); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) { + if (!NoAdvance) RetreatPastEmptyBuckets(); + return; + } +#endif if (!NoAdvance) AdvancePastEmptyBuckets(); } @@ -1103,10 +1149,18 @@ reference operator*() const { assert(isHandleInSync() && "invalid iterator access!"); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) + return Ptr[-1]; +#endif return *Ptr; } pointer operator->() const { assert(isHandleInSync() && "invalid iterator access!"); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) + return &(Ptr[-1]); +#endif return Ptr; } @@ -1127,6 +1181,13 @@ inline DenseMapIterator& operator++() { // Preincrement assert(isHandleInSync() && "invalid iterator access!"); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (shouldReverseIterate()) { + --Ptr; + RetreatPastEmptyBuckets(); + return *this; + } +#endif ++Ptr; AdvancePastEmptyBuckets(); return *this; @@ -1138,6 +1199,7 @@ private: void AdvancePastEmptyBuckets() { + assert(Ptr <= End); const KeyT Empty = KeyInfoT::getEmptyKey(); const KeyT Tombstone = KeyInfoT::getTombstoneKey(); @@ -1145,6 +1207,17 @@ KeyInfoT::isEqual(Ptr->getFirst(), Tombstone))) ++Ptr; } +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + void RetreatPastEmptyBuckets() { + assert(Ptr >= End); + const KeyT Empty = KeyInfoT::getEmptyKey(); + const KeyT Tombstone = KeyInfoT::getTombstoneKey(); + + while (Ptr != End && (KeyInfoT::isEqual(Ptr[-1].getFirst(), Empty) || + KeyInfoT::isEqual(Ptr[-1].getFirst(), Tombstone))) + --Ptr; + } +#endif }; template Index: include/llvm/ADT/SmallPtrSet.h =================================================================== --- include/llvm/ADT/SmallPtrSet.h +++ include/llvm/ADT/SmallPtrSet.h @@ -225,7 +225,7 @@ explicit SmallPtrSetIteratorImpl(const void *const *BP, const void*const *E) : Bucket(BP), End(E) { #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (ReverseIterate::value) { + if (shouldReverseIterate()) { RetreatIfNotValid(); return; } @@ -282,7 +282,7 @@ const PtrTy operator*() const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (ReverseIterate::value) { + if (shouldReverseIterate()) { assert(Bucket > End); return PtrTraits::getFromVoidPointer(const_cast(Bucket[-1])); } @@ -293,7 +293,7 @@ inline SmallPtrSetIterator& operator++() { // Preincrement #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (ReverseIterate::value) { + if (shouldReverseIterate()) { --Bucket; RetreatIfNotValid(); return *this; @@ -397,7 +397,7 @@ iterator begin() const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (ReverseIterate::value) + if (shouldReverseIterate()) return makeIterator(EndPointer() - 1); #endif return makeIterator(CurArray); @@ -408,7 +408,7 @@ /// Create an iterator that dereferences to same place as the given pointer. iterator makeIterator(const void *const *P) const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (ReverseIterate::value) + if (shouldReverseIterate()) return iterator(P == EndPointer() ? CurArray : P + 1, CurArray); #endif return iterator(P, EndPointer()); Index: include/llvm/Support/ReverseIteration.h =================================================================== --- include/llvm/Support/ReverseIteration.h +++ include/llvm/Support/ReverseIteration.h @@ -2,16 +2,30 @@ #define LLVM_SUPPORT_REVERSEITERATION_H #include "llvm/Config/abi-breaking.h" +#include +#include namespace llvm { #if LLVM_ENABLE_ABI_BREAKING_CHECKS + template struct ReverseIterate { static bool value; }; #if LLVM_ENABLE_REVERSE_ITERATION template bool ReverseIterate::value = true; #else template bool ReverseIterate::value = false; #endif -#endif + +// For containers like maps which compute a hash based on the key, we reverse +// iterate only if the key is pointer-like. For other containers like sets +// which do not compute a hash, std::is_pointer::value will always +// return true. So the decision whether to reverse iterate will depend only on +// ReverseIterate::value. +template +bool shouldReverseIterate() { + return ReverseIterate::value && + std::is_pointer::value; } #endif +} +#endif Index: unittests/ADT/ReverseIterationTest.cpp =================================================================== --- unittests/ADT/ReverseIterationTest.cpp +++ unittests/ADT/ReverseIterationTest.cpp @@ -11,14 +11,90 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "gtest/gtest.h" #if LLVM_ENABLE_ABI_BREAKING_CHECKS using namespace llvm; -TEST(ReverseIterationTest, SmallPtrSetTest) { +TEST(ReverseIterationTest, DenseMapTest1) { + DenseMap Map; + void *Keys[] = { (void*)0x1, (void*)0x2, (void*)0x3, (void*)0x4 }; + void *ReverseKeys[] = { (void*)0x4, (void*)0x3, (void*)0x2, (void*)0x1 }; + + for (auto *Key: Keys) + Map[Key] = 0; + + // Check forward iteration. + ReverseIterate::value = false; + for (const auto &Tuple : zip(Map, Keys)) + ASSERT_EQ(std::get<0>(Tuple).first, std::get<1>(Tuple)); + + // Check operator++ (post-increment) in forward iteration. + int i = 0; + for (auto begin = Map.begin(), end = Map.end(); + begin != end; i++) { + ASSERT_EQ(begin->first, Keys[i]); + begin++; + } + + // Check reverse iteration. + ReverseIterate::value = true; + for (const auto &Tuple : zip(Map, ReverseKeys)) + ASSERT_EQ(std::get<0>(Tuple).first, std::get<1>(Tuple)); + + // Check operator++ (post-increment) in reverse iteration. + i = 0; + for (auto begin = Map.begin(), end = Map.end(); + begin != end; i++) { + ASSERT_EQ(begin->first, ReverseKeys[i]); + begin++; + } +} + +TEST(ReverseIterationTest, DenseMapTest2) { + // For a DenseMap with non-pointer keys, forward iteration equals reverse + // iteration. + DenseMap Map; + int Keys[] = { 1, 2, 3, 4 }; + + for (auto Key: Keys) + Map[Key] = 0; + int IterKeys[4]; + int i = 0; + for (auto Key : Map) + IterKeys[i++] = Key.first; + + // Check forward iteration. + ReverseIterate::value = false; + for (const auto &Tuple : zip(Map, IterKeys)) + ASSERT_EQ(std::get<0>(Tuple).first, std::get<1>(Tuple)); + + // Check operator++ (post-increment) in forward iteration. + i = 0; + for (auto begin = Map.begin(), end = Map.end(); + begin != end; i++) { + ASSERT_EQ(begin->first, IterKeys[i]); + begin++; + } + + // Check reverse iteration. + ReverseIterate::value = true; + for (const auto &Tuple : zip(Map, IterKeys)) + ASSERT_EQ(std::get<0>(Tuple).first, std::get<1>(Tuple)); + + // Check operator++ (post-increment) in reverse iteration. + i = 0; + for (auto begin = Map.begin(), end = Map.end(); + begin != end; i++) { + ASSERT_EQ(begin->first, IterKeys[i]); + begin++; + } +} + +TEST(ReverseIterationTest, SmallPtrSetTest) { SmallPtrSet Set; void *Ptrs[] = { (void*)0x1, (void*)0x2, (void*)0x3, (void*)0x4 }; void *ReversePtrs[] = { (void*)0x4, (void*)0x3, (void*)0x2, (void*)0x1 }; @@ -42,11 +118,10 @@ for (const auto &Tuple : zip(Set, ReversePtrs)) ASSERT_EQ(std::get<0>(Tuple), std::get<1>(Tuple)); - // Check operator++ (post-increment) in reverse iteration. + // Check operator++ (post-increment) in reverse iteration. i = 0; for (auto begin = Set.begin(), end = Set.end(); begin != end; i++) ASSERT_EQ(*begin++, ReversePtrs[i]); - } #endif