diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/ADT/TrieRawHashMap.h @@ -0,0 +1,404 @@ +//===- TrieRawHashMap.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_TRIERAWHASHMAP_H +#define LLVM_ADT_TRIERAWHASHMAP_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include +#include + +namespace llvm { + +class raw_ostream; + +/// TrieRawHashMap - is a lock-free thread-safe trie that can be used to +/// store/index data based on a hash value. It can be customized to work with +/// any hash algorithm or store any data. +/// +/// Data structure: +/// Data node stored in the Trie contains both hash and data: +/// struct { +/// HashT Hash; +/// DataT Data; +/// }; +/// +/// Data is stored/indexed via a prefix tree, where each node in the tree can be +/// either the root, a sub-trie or a data node. Assuming a 4-bit hash and two +/// data objects {0001, A} and {0100, B}, it can be stored in a trie +/// (assuming Root has 2 bits, SubTrie has 1 bit): +/// +--------+ +/// |Root[00]| -> {0001, A} +/// | [01]| -> {0100, B} +/// | [10]| (empty) +/// | [11]| (empty) +/// +--------+ +/// +/// Inserting a new object {0010, C} will result in: +/// +--------+ +----------+ +/// |Root[00]| -> |SubTrie[0]| -> {0001, A} +/// | | | [1]| -> {0010, C} +/// | | +----------+ +/// | [01]| -> {0100, B} +/// | [10]| (empty) +/// | [11]| (empty) +/// +--------+ +/// Note object A is sunk down to a sub-trie during the insertion. 
All the +/// nodes are inserted through compare-exchange to ensure thread-safety and +/// lock-freedom. +/// +/// To find an object in the trie, walk the tree with prefix of the hash until +/// the data node is found. Then the hash is compared with the hash stored in +/// the data node to see if it is the same object. +/// +/// Hash collision is not allowed so it is recommended to use the trie with a +/// "strong" hashing algorithm. A well-distributed hash can also result in +/// better performance and memory usage. +/// +/// It currently does not support iteration and deletion. + +/// Base class for a lock-free thread-safe hash-mapped trie. +class ThreadSafeTrieRawHashMapBase { +public: + static constexpr size_t TrieContentBaseSize = 4; + static constexpr size_t DefaultNumRootBits = 6; + static constexpr size_t DefaultNumSubtrieBits = 4; + +private: + template struct AllocValueType { + char Base[TrieContentBaseSize]; + std::aligned_union_t Content; + }; + +protected: + template + static constexpr size_t DefaultContentAllocSize = sizeof(AllocValueType); + + template + static constexpr size_t DefaultContentAllocAlign = alignof(AllocValueType); + + template + static constexpr size_t DefaultContentOffset = + offsetof(AllocValueType, Content); + +public: + void operator delete(void *Ptr) { ::free(Ptr); } + + LLVM_DUMP_METHOD void dump() const; + void print(raw_ostream &OS) const; + +protected: + /// Result of a lookup. Suitable for an insertion hint. Maybe could be + /// expanded into an iterator of sorts, but likely not useful (visiting + /// everything in the trie should probably be done some way other than + /// through an iterator pattern). + class PointerBase { + protected: + void *get() const { return I == -2u ? 
P : nullptr; } + + public: + PointerBase() noexcept = default; + PointerBase(PointerBase &&) = default; + PointerBase(const PointerBase &) = default; + PointerBase &operator=(PointerBase &&) = default; + PointerBase &operator=(const PointerBase &) = default; + + private: + friend class ThreadSafeTrieRawHashMapBase; + explicit PointerBase(void *Content) : P(Content), I(-2u) {} + PointerBase(void *P, unsigned I, unsigned B) : P(P), I(I), B(B) {} + + bool isHint() const { return I != -1u && I != -2u; } + + void *P = nullptr; + unsigned I = -1u; + unsigned B = 0; + }; + + /// Find the stored content with hash. + PointerBase find(ArrayRef Hash) const; + + /// Insert and return the stored content. + /// If the hash is already in the trie, it returns false. + std::pair + insert(PointerBase Hint, ArrayRef Hash, + function_ref Hash)> + Constructor); + + ThreadSafeTrieRawHashMapBase() = delete; + + ThreadSafeTrieRawHashMapBase( + size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset, + std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt); + + /// Destructor, which asserts if there's anything to do. Subclasses should + /// call \a destroyImpl(). + /// + /// \pre \a destroyImpl() was already called. + ~ThreadSafeTrieRawHashMapBase(); + void destroyImpl(function_ref Destructor); + + ThreadSafeTrieRawHashMapBase(ThreadSafeTrieRawHashMapBase &&RHS); + + // Move assignment can be implemented in a thread-safe way if NumRootBits and + // NumSubtrieBits are stored inside the Root. + ThreadSafeTrieRawHashMapBase & + operator=(ThreadSafeTrieRawHashMapBase &&RHS) = delete; + + // No copy. + ThreadSafeTrieRawHashMapBase(const ThreadSafeTrieRawHashMapBase &) = delete; + ThreadSafeTrieRawHashMapBase & + operator=(const ThreadSafeTrieRawHashMapBase &) = delete; + + // Debug functions. Implementation details and not guaranteed to be + // thread-safe. 
+ PointerBase getRoot() const; + unsigned getStartBit(PointerBase P) const; + unsigned getNumBits(PointerBase P) const; + unsigned getNumSlotUsed(PointerBase P) const; + std::string getTriePrefixAsString(PointerBase P) const; + unsigned getNumTries() const; + // Visit next trie in the allocation chain. + PointerBase getNextTrie(PointerBase P) const; + +private: + friend class TrieRawHashMapTestHelper; + const unsigned short ContentAllocSize; + const unsigned short ContentAllocAlign; + const unsigned short ContentOffset; + unsigned short NumRootBits; + unsigned short NumSubtrieBits; + struct ImplType; + // ImplPtr is owned by ThreadSafeTrieRawHashMapBase and needs to be freed in + // destroyImpl. + std::atomic ImplPtr; + ImplType &getOrCreateImpl(); + ImplType *getImpl() const; +}; + +/// Lock-free thread-safe hash-mapped trie. +template +class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase { +public: + using HashT = std::array; + + class LazyValueConstructor; + struct value_type { + const HashT Hash; + T Data; + + value_type(value_type &&) = default; + value_type(const value_type &) = default; + + value_type(ArrayRef Hash, const T &Data) + : Hash(copyHash(Hash)), Data(Data) {} + value_type(ArrayRef Hash, T &&Data) + : Hash(copyHash(Hash)), Data(std::move(Data)) {} + + private: + friend class LazyValueConstructor; + + struct EmplaceTag {}; + template + value_type(ArrayRef Hash, EmplaceTag, ArgsT &&...Args) + : Hash(copyHash(Hash)), Data(std::forward(Args)...) 
{} + + static HashT copyHash(ArrayRef HashRef) { + HashT Hash; + std::copy(HashRef.begin(), HashRef.end(), Hash.data()); + return Hash; + } + }; + + using ThreadSafeTrieRawHashMapBase::operator delete; + using HashType = HashT; + + using ThreadSafeTrieRawHashMapBase::dump; + using ThreadSafeTrieRawHashMapBase::print; + +private: + template class PointerImpl : PointerBase { + friend class ThreadSafeTrieRawHashMap; + + ValueT *get() const { + if (void *B = PointerBase::get()) + return reinterpret_cast(B); + return nullptr; + } + + public: + ValueT &operator*() const { + assert(get()); + return *get(); + } + ValueT *operator->() const { + assert(get()); + return get(); + } + explicit operator bool() const { return get(); } + + PointerImpl() = default; + PointerImpl(PointerImpl &&) = default; + PointerImpl(const PointerImpl &) = default; + PointerImpl &operator=(PointerImpl &&) = default; + PointerImpl &operator=(const PointerImpl &) = default; + + protected: + PointerImpl(PointerBase Result) : PointerBase(Result) {} + }; + +public: + class pointer; + class const_pointer; + class pointer : public PointerImpl { + friend class ThreadSafeTrieRawHashMap; + friend class const_pointer; + + public: + pointer() = default; + pointer(pointer &&) = default; + pointer(const pointer &) = default; + pointer &operator=(pointer &&) = default; + pointer &operator=(const pointer &) = default; + + private: + pointer(PointerBase Result) : pointer::PointerImpl(Result) {} + }; + + class const_pointer : public PointerImpl { + friend class ThreadSafeTrieRawHashMap; + + public: + const_pointer() = default; + const_pointer(const_pointer &&) = default; + const_pointer(const const_pointer &) = default; + const_pointer &operator=(const_pointer &&) = default; + const_pointer &operator=(const const_pointer &) = default; + + const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {} + + private: + const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {} + }; + + class 
LazyValueConstructor { + public: + value_type &operator()(T &&RHS) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) value_type(Hash, std::move(RHS))); + } + value_type &operator()(const T &RHS) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) value_type(Hash, RHS)); + } + template value_type &emplace(ArgsT &&...Args) { + assert(Mem && "Constructor already called, or moved away"); + return assign(::new (Mem) + value_type(Hash, typename value_type::EmplaceTag{}, + std::forward(Args)...)); + } + + LazyValueConstructor(LazyValueConstructor &&RHS) + : Mem(RHS.Mem), Result(RHS.Result), Hash(RHS.Hash) { + RHS.Mem = nullptr; // Moved away, cannot call. + } + ~LazyValueConstructor() { assert(!Mem && "Constructor never called!"); } + + private: + value_type &assign(value_type *V) { + Mem = nullptr; + Result = V; + return *V; + } + friend class ThreadSafeTrieRawHashMap; + LazyValueConstructor() = delete; + LazyValueConstructor(void *Mem, value_type *&Result, ArrayRef Hash) + : Mem(Mem), Result(Result), Hash(Hash) { + assert(Hash.size() == sizeof(HashT) && "Invalid hash"); + assert(Mem && "Invalid memory for construction"); + } + void *Mem; + value_type *&Result; + ArrayRef Hash; + }; + + /// Insert with a hint. Default-constructed hint will work, but it's + /// recommended to start with a lookup to avoid overhead in object creation + /// if it already exists. + /// Return false if the hash is already in the map. 
+ std::pair + insertLazy(const_pointer Hint, ArrayRef Hash, + function_ref OnConstruct) { + auto Result = ThreadSafeTrieRawHashMapBase::insert( + Hint, Hash, [&](void *Mem, ArrayRef Hash) { + value_type *Result = nullptr; + OnConstruct(LazyValueConstructor(Mem, Result, Hash)); + return Result->Hash.data(); + }); + return {pointer(Result.first), Result.second}; + } + + std::pair + insertLazy(ArrayRef Hash, + function_ref OnConstruct) { + return insertLazy(const_pointer(), Hash, OnConstruct); + } + + std::pair insert(const_pointer Hint, value_type &&HashedData) { + return insertLazy(Hint, HashedData.Hash, [&](LazyValueConstructor C) { + C(std::move(HashedData.Data)); + }); + } + + std::pair insert(const_pointer Hint, + const value_type &HashedData) { + return insertLazy(Hint, HashedData.Hash, + [&](LazyValueConstructor C) { C(HashedData.Data); }); + } + + pointer find(ArrayRef Hash) { + assert(Hash.size() == std::tuple_size::value); + return ThreadSafeTrieRawHashMapBase::find(Hash); + } + + const_pointer find(ArrayRef Hash) const { + assert(Hash.size() == std::tuple_size::value); + return ThreadSafeTrieRawHashMapBase::find(Hash); + } + + ThreadSafeTrieRawHashMap(std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt) + : ThreadSafeTrieRawHashMapBase(DefaultContentAllocSize, + DefaultContentAllocAlign, + DefaultContentOffset, + NumRootBits, NumSubtrieBits) {} + + ~ThreadSafeTrieRawHashMap() { + if constexpr (std::is_trivially_destructible::value) + this->destroyImpl(nullptr); + else + this->destroyImpl( + [](void *P) { static_cast(P)->~value_type(); }); + } + + // Move constructor okay. + ThreadSafeTrieRawHashMap(ThreadSafeTrieRawHashMap &&) = default; + + // No move assignment or any copy. 
+ ThreadSafeTrieRawHashMap &operator=(ThreadSafeTrieRawHashMap &&) = delete; + ThreadSafeTrieRawHashMap(const ThreadSafeTrieRawHashMap &) = delete; + ThreadSafeTrieRawHashMap & + operator=(const ThreadSafeTrieRawHashMap &) = delete; +}; + +} // namespace llvm + +#endif // LLVM_ADT_TRIERAWHASHMAP_H diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -225,6 +225,7 @@ TimeProfiler.cpp Timer.cpp ToolOutputFile.cpp + TrieRawHashMap.cpp TrigramIndex.cpp Twine.cpp TypeSize.cpp diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h new file mode 100644 --- /dev/null +++ b/llvm/lib/Support/TrieHashIndexGenerator.h @@ -0,0 +1,89 @@ +//===- TrieHashIndexGenerator.h ---------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H +#define LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H + +#include "llvm/ADT/ArrayRef.h" +#include + +namespace llvm { + +struct IndexGenerator { + size_t NumRootBits; + size_t NumSubtrieBits; + ArrayRef Bytes; + std::optional StartBit = std::nullopt; + + size_t getNumBits() const { + assert(StartBit); + size_t TotalNumBits = Bytes.size() * 8; + assert(*StartBit <= TotalNumBits); + return std::min(*StartBit ? NumSubtrieBits : NumRootBits, + TotalNumBits - *StartBit); + } + size_t next() { + size_t Index; + if (!StartBit) { + StartBit = 0; + Index = getIndex(Bytes, *StartBit, NumRootBits); + } else { + *StartBit += *StartBit ? 
NumSubtrieBits : NumRootBits; + assert((*StartBit - NumRootBits) % NumSubtrieBits == 0); + Index = getIndex(Bytes, *StartBit, NumSubtrieBits); + } + return Index; + } + + size_t hint(unsigned Index, unsigned Bit) { + assert(Index >= 0); + assert(Bit < Bytes.size() * 8); + assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0); + StartBit = Bit; + return Index; + } + + size_t getCollidingBits(ArrayRef CollidingBits) const { + assert(StartBit); + return getIndex(CollidingBits, *StartBit, NumSubtrieBits); + } + + static size_t getIndex(ArrayRef Bytes, size_t StartBit, + size_t NumBits) { + assert(StartBit < Bytes.size() * 8); + + Bytes = Bytes.drop_front(StartBit / 8u); + StartBit %= 8u; + size_t Index = 0; + for (uint8_t Byte : Bytes) { + size_t ByteStart = 0, ByteEnd = 8; + if (StartBit) { + ByteStart = StartBit; + Byte &= (1u << (8 - StartBit)) - 1u; + StartBit = 0; + } + size_t CurrentNumBits = ByteEnd - ByteStart; + if (CurrentNumBits > NumBits) { + Byte >>= CurrentNumBits - NumBits; + CurrentNumBits = NumBits; + } + Index <<= CurrentNumBits; + Index |= Byte & ((1u << CurrentNumBits) - 1u); + + assert(NumBits >= CurrentNumBits); + NumBits -= CurrentNumBits; + if (!NumBits) + break; + } + return Index; + } +}; + +} // namespace llvm + +#endif // LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Support/TrieRawHashMap.cpp @@ -0,0 +1,494 @@ +//===- TrieRawHashMap.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/TrieRawHashMap.h" +#include "TrieHashIndexGenerator.h" +#include "llvm/ADT/LazyAtomicPointer.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ThreadSafeAllocator.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace llvm; + +namespace { +struct TrieNode { + const bool IsSubtrie = false; + + TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {} + + static void *operator new(size_t Size) { return ::malloc(Size); } + void operator delete(void *Ptr) { ::free(Ptr); } +}; + +struct TrieContent final : public TrieNode { + const uint8_t ContentOffset; + const uint8_t HashSize; + const uint8_t HashOffset; + + void *getValuePointer() const { + auto Content = reinterpret_cast(this) + ContentOffset; + return const_cast(Content); + } + + ArrayRef getHash() const { + auto *Begin = reinterpret_cast(this) + HashOffset; + return makeArrayRef(Begin, Begin + HashSize); + } + + TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset) + : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset), + HashSize(HashSize), HashOffset(HashOffset) {} +}; +static_assert(sizeof(TrieContent) == + ThreadSafeTrieRawHashMapBase::TrieContentBaseSize, + "Check header assumption!"); + +class TrieSubtrie final : public TrieNode { +public: + TrieNode *get(size_t I) const { return Slots[I].load(); } + + TrieSubtrie * + sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI, + function_ref)> Saver); + + static std::unique_ptr create(size_t StartBit, size_t NumBits); + + explicit TrieSubtrie(size_t StartBit, size_t NumBits); + +private: + // FIXME: Use a bitset to speed up access: + // + // std::array, NumSlots/64> IsSet; + // + // This will avoid needing to visit sparsely filled slots in + // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when 
there's a non-trivial + // destructor. + // + // It would also greatly speed up iteration, if we add that some day, and + // allow get() to return one level sooner. + // + // This would be the algorithm for updating IsSet (after updating Slots): + // + // std::atomic &Bits = IsSet[I.High]; + // const uint64_t NewBit = 1ULL << I.Low; + // uint64_t Old = 0; + // while (!Bits.compare_exchange_weak(Old, Old | NewBit)) + // ; + + // For debugging. + unsigned StartBit = 0; + unsigned NumBits = 0; + friend class llvm::ThreadSafeTrieRawHashMapBase; + +public: + /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie. + std::atomic Next; + + /// The (co-allocated) slots of the subtrie. + MutableArrayRef> Slots; +}; +} // end namespace + +namespace llvm { +template <> struct isa_impl { + static inline bool doit(const TrieNode &TN) { return !TN.IsSubtrie; } +}; +template <> struct isa_impl { + static inline bool doit(const TrieNode &TN) { return TN.IsSubtrie; } +}; +} // end namespace llvm + +static size_t getTrieTailSize(size_t StartBit, size_t NumBits) { + assert(NumBits < 20 && "Tries should have fewer than ~1M slots"); + return sizeof(TrieNode *) * (1u << NumBits); +} + +std::unique_ptr TrieSubtrie::create(size_t StartBit, + size_t NumBits) { + size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits); + void *Memory = ::malloc(Size); + TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits); + return std::unique_ptr(S); +} + +TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits) + : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Next(nullptr), + Slots(reinterpret_cast *>( + reinterpret_cast(this) + sizeof(TrieSubtrie)), + (1u << NumBits)) { + for (auto *I = Slots.begin(), *E = Slots.end(); I != E; ++I) + new (I) LazyAtomicPointer(nullptr); + + static_assert( + std::is_trivially_destructible>::value, + "Expected no work in destructor for TrieNode"); +} + +TrieSubtrie *TrieSubtrie::sink( + size_t I, TrieContent &Content, 
size_t NumSubtrieBits, size_t NewI, + function_ref)> Saver) { + assert(NumSubtrieBits > 0); + std::unique_ptr S = create(StartBit + NumBits, NumSubtrieBits); + + assert(NewI < S->Slots.size()); + S->Slots[NewI].store(&Content); + + TrieNode *ExistingNode = &Content; + assert(I < Slots.size()); + if (Slots[I].compare_exchange_strong(ExistingNode, S.get())) + return Saver(std::move(S)); + + // Another thread created a subtrie already. Return it and let "S" be + // destructed. + return cast(ExistingNode); +} + +struct ThreadSafeTrieRawHashMapBase::ImplType { + static ImplType *create(size_t StartBit, size_t NumBits) { + size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits); + void *Memory = ::malloc(Size); + return ::new (Memory) ImplType(StartBit, NumBits); + } + + TrieSubtrie *save(std::unique_ptr S) { + assert(!S->Next && "Expected S to a freshly-constructed leaf"); + + TrieSubtrie *CurrentHead = nullptr; + // Add ownership of "S" to front of the list, so that Root -> S -> + // Root.Next. This works by repeatedly setting S->Next to a candidate value + // of Root.Next (initially nullptr), then setting Root.Next to S once the + // candidate matches reality. + while (!Root.Next.compare_exchange_weak(CurrentHead, S.get())) + S->Next.exchange(CurrentHead); + + // Ownership transferred to subtrie. + return S.release(); + } + + static void *operator new(size_t Size) { return ::malloc(Size); } + void operator delete(void *Ptr) { ::free(Ptr); } + + /// FIXME: This should take a function that allocates and constructs the + /// content lazily (taking the hash as a separate parameter), in case of + /// collision. + ThreadSafeAllocator ContentAlloc; + TrieSubtrie Root; // Must be last! Tail-allocated. 
+ +private: + ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {} +}; + +ThreadSafeTrieRawHashMapBase::ImplType & +ThreadSafeTrieRawHashMapBase::getOrCreateImpl() { + if (ImplType *Impl = ImplPtr.load()) + return *Impl; + + // Create a new ImplType and store it if another thread doesn't do so first. + // If another thread wins this one is destroyed locally. + std::unique_ptr Impl(ImplType::create(0, NumRootBits)); + ImplType *ExistingImpl = nullptr; + if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get())) + return *Impl.release(); + + return *ExistingImpl; +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::find(ArrayRef Hash) const { + assert(!Hash.empty() && "Uninitialized hash"); + + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return PointerBase(); + + TrieSubtrie *S = &Impl->Root; + IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; + size_t Index = IndexGen.next(); + for (;;) { + // Try to set the content. + TrieNode *Existing = S->get(Index); + if (!Existing) + return PointerBase(S, Index, *IndexGen.StartBit); + + // Check for an exact match. + if (auto *ExistingContent = dyn_cast(Existing)) + return ExistingContent->getHash() == Hash + ? PointerBase(ExistingContent->getValuePointer()) + : PointerBase(S, Index, *IndexGen.StartBit); + + Index = IndexGen.next(); + S = cast(Existing); + } +} + +std::pair +ThreadSafeTrieRawHashMapBase::insert( + PointerBase Hint, ArrayRef Hash, + function_ref Hash)> + Constructor) { + assert(!Hash.empty() && "Uninitialized hash"); + + ImplType &Impl = getOrCreateImpl(); + TrieSubtrie *S = &Impl.Root; + IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; + size_t Index; + if (Hint.isHint()) { + S = static_cast(Hint.P); + Index = IndexGen.hint(Hint.I, Hint.B); + } else { + Index = IndexGen.next(); + } + + for (;;) { + // Load the node from the slot, allocating and calling the constructor if + // the slot is empty. 
+ bool Generated = false; + TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() { + Generated = true; + + // Construct the value itself at the tail. + uint8_t *Memory = reinterpret_cast( + Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign)); + const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash); + + // Construct the TrieContent header, passing in the offset to the hash. + TrieContent *Content = ::new (Memory) + TrieContent(ContentOffset, Hash.size(), HashStorage - Memory); + assert(Hash == Content->getHash() && "Hash not properly initialized"); + return Content; + }); + // If we just generated it, return it! + if (Generated) + return {PointerBase(cast(Existing).getValuePointer()), true}; + + if (isa(Existing)) { + S = &cast(Existing); + Index = IndexGen.next(); + continue; + } + + // Return the existing content if it's an exact match! + auto &ExistingContent = cast(Existing); + if (ExistingContent.getHash() == Hash) + return {PointerBase(ExistingContent.getValuePointer()), false}; + + // Sink the existing content as long as the indexes match. + for (;;) { + size_t NextIndex = IndexGen.next(); + size_t NewIndexForExistingContent = + IndexGen.getCollidingBits(ExistingContent.getHash()); + S = S->sink(Index, ExistingContent, IndexGen.getNumBits(), + NewIndexForExistingContent, + [&Impl](std::unique_ptr S) { + return Impl.save(std::move(S)); + }); + Index = NextIndex; + + // Found the difference. 
+ if (NextIndex != NewIndexForExistingContent) + break; + } + } +} + +static void printHexDigit(raw_ostream &OS, uint8_t Digit) { + if (Digit < 10) + OS << char(Digit + '0'); + else + OS << char(Digit - 10 + 'a'); +} + +static void printPrefix(raw_ostream &OS, StringRef Prefix) { + while (Prefix.size() >= 4) { + uint8_t Digit; + bool ErrorParsingBinary = Prefix.take_front(4).getAsInteger(2, Digit); + assert(!ErrorParsingBinary); + (void)ErrorParsingBinary; + printHexDigit(OS, Digit); + Prefix = Prefix.drop_front(4); + } + if (!Prefix.empty()) + OS << "[" << Prefix << "]"; +} + +ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( + size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset, + std::optional NumRootBits, std::optional NumSubtrieBits) + : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign), + ContentOffset(ContentOffset), + NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits), + NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits), + ImplPtr(nullptr) { + assert((!NumRootBits || *NumRootBits < 20) && + "Root should have fewer than ~1M slots"); + assert((!NumSubtrieBits || *NumSubtrieBits < 10) && + "Subtries should have fewer than ~1K slots"); +} + +ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( + ThreadSafeTrieRawHashMapBase &&RHS) + : ContentAllocSize(RHS.ContentAllocSize), + ContentAllocAlign(RHS.ContentAllocAlign), + ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits), + NumSubtrieBits(RHS.NumSubtrieBits) { + // Steal the root from RHS. + ImplPtr = RHS.ImplPtr.exchange(nullptr); +} + +ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() { + assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()"); +} + +void ThreadSafeTrieRawHashMapBase::destroyImpl( + function_ref Destructor) { + std::unique_ptr Impl(ImplPtr.exchange(nullptr)); + if (!Impl) + return; + + // Destroy content nodes throughout trie. 
Avoid destroying any subtries since + // we need TrieNode::classof() to find the content nodes. + // + // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them + // to facilitate sparse iteration here. + if (Destructor) + for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load()) + for (auto &Slot : Trie->Slots) + if (auto *Content = dyn_cast_or_null(Slot.load())) + Destructor(Content->getValuePointer()); + + // Destroy the subtries. Incidentally, this destroys them in the reverse order + // of saving. + TrieSubtrie *Trie = Impl->Root.Next; + while (Trie) { + TrieSubtrie *Next = Trie->Next.exchange(nullptr); + delete Trie; + Trie = Next; + } +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::getRoot() const { + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return PointerBase(); + return PointerBase(&Impl->Root); +} + +unsigned ThreadSafeTrieRawHashMapBase::getStartBit( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + if (auto *S = dyn_cast((TrieNode *)P.P)) + return S->StartBit; + return 0; +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumBits( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + if (auto *S = dyn_cast((TrieNode *)P.P)) + return S->NumBits; + return 0; +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return 0; + auto *S = dyn_cast((TrieNode *)P.P); + if (!S) + return 0; + unsigned Num = 0; + for (unsigned I = 0, E = S->Slots.size(); I < E; ++I) + if (auto *E = S->Slots[I].load()) + ++Num; + return Num; +} + +std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return std::string(); + + auto *S = 
dyn_cast((TrieNode *)P.P); + if (!S || !S->IsSubtrie) + return std::string(); + + // Find a TrieContent node which has hash stored. Depth search following the + // first used slot until a TrieContent node is found. + TrieSubtrie *Current = S; + TrieContent *Node = nullptr; + while (Current) { + TrieSubtrie *Next = nullptr; + // find first used slot in the trie. + for (unsigned I = 0, E = Current->Slots.size(); I < E; ++I) { + auto *S = Current->get(I); + if (!S) + continue; + + if (auto *Content = dyn_cast(S)) + Node = Content; + else if (auto *Sub = dyn_cast(S)) + Next = Sub; + break; + } + + // Found the node. + if (Node) + break; + + // Continue to the next level if the node is not found. + Current = Next; + } + + assert(Node && "malformed trie, cannot find TrieContent on leaf node"); + // The prefix for the current trie is the first `StartBit` of the content + // stored underneath this subtrie. + std::string Bits; + for (unsigned I = 0, E = S->StartBit; I < E; ++I) { + unsigned Index = I / 8; + unsigned Offset = 7 - I % 8; + Bits.push_back('0' + ((Node->getHash()[Index] >> Offset) & 1)); + } + + std::string Str; + raw_string_ostream SS(Str); + printPrefix(SS, Bits); + return SS.str(); +} + +unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const { + ImplType *Impl = ImplPtr.load(); + if (!Impl) + return 0; + unsigned Num = 0; + for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load()) + ++Num; + return Num; +} + +ThreadSafeTrieRawHashMapBase::PointerBase +ThreadSafeTrieRawHashMapBase::getNextTrie( + ThreadSafeTrieRawHashMapBase::PointerBase P) const { + assert(!P.isHint() && "Not a valid trie"); + if (!P.P) + return PointerBase(); + auto *S = dyn_cast((TrieNode *)P.P); + if (!S) + return PointerBase(); + if (auto *E = S->Next.load()) + return PointerBase(E); + return PointerBase(); +} diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt 
@@ -79,6 +79,7 @@ StringSetTest.cpp StringSwitchTest.cpp TinyPtrVectorTest.cpp + TrieRawHashMapTest.cpp TwineTest.cpp TypeSwitchTest.cpp TypeTraitsTest.cpp diff --git a/llvm/unittests/ADT/TrieRawHashMapTest.cpp b/llvm/unittests/ADT/TrieRawHashMapTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ADT/TrieRawHashMapTest.cpp @@ -0,0 +1,398 @@ +//===- TrieRawHashMapTest.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/TrieRawHashMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/SHA1.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace llvm { +class TrieRawHashMapTestHelper { +public: + TrieRawHashMapTestHelper() = default; + + void setTrie(ThreadSafeTrieRawHashMapBase *T) { Trie = T; } + + ThreadSafeTrieRawHashMapBase::PointerBase getRoot() const { + return Trie->getRoot(); + } + unsigned getStartBit(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getStartBit(P); + } + unsigned getNumBits(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNumBits(P); + } + unsigned getNumSlotUsed(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNumSlotUsed(P); + } + unsigned getNumTries() const { return Trie->getNumTries(); } + std::string + getTriePrefixAsString(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getTriePrefixAsString(P); + } + ThreadSafeTrieRawHashMapBase::PointerBase + getNextTrie(ThreadSafeTrieRawHashMapBase::PointerBase P) const { + return Trie->getNextTrie(P); + } + +private: + ThreadSafeTrieRawHashMapBase *Trie = nullptr; // Not owned; set by setTrie(). +}; +} // namespace llvm + +namespace { +template +class SimpleTrieHashMapTest : public 
TrieRawHashMapTestHelper, + public ::testing::Test { +public: + using NumType = uint64_t; + using HashType = std::array; + using TrieType = ThreadSafeTrieRawHashMap; + + SimpleTrieHashMapTest() : Trie(RootBits, SubtrieBits) { + TrieRawHashMapTestHelper::setTrie(&Trie); + } + + // Use the number itself as hash to test the pathological case. + static HashType hash(NumType Num) { + NumType HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashType Hash; + memcpy(&Hash[0], &HashN, sizeof(HashType)); + return Hash; + } + +protected: + TrieType Trie; +}; + +// Use root and subtrie sizes of 1 so this gets sunk quite deep. +using SmallNodeTrieTest = + SimpleTrieHashMapTest; +TEST_F(SmallNodeTrieTest, TrieAllocation) { + NumType Numbers[] = { + 0x0, std::numeric_limits::max(), 0x1, 0x2, + 0x3, std::numeric_limits::max() - 1u, + }; + + unsigned ExpectedTries[] = { + 1, // Allocate Root. + 1, // Both on the root. + 64, // 0 and 1 sink all the way down. + 64, // No new allocation needed. + 65, // Need a new node between 2 and 3. + 65 + 63, // 63 new allocations to sink two big numbers all the way. + }; + + const char *ExpectedPrefix[] = { + "", // Root. + "", // Root. + "000000000000000[000]", + "000000000000000[000]", + "000000000000000[001]", + "fffffffffffffff[111]", + }; + + for (unsigned I = 0; I < 6; ++I) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(Numbers[I])); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(Numbers[I]), Numbers[I])); + EXPECT_EQ(getNumTries(), ExpectedTries[I]); + EXPECT_EQ(getTriePrefixAsString(getNextTrie(getRoot())), ExpectedPrefix[I]); + } +} + +TEST_F(SmallNodeTrieTest, TrieStructure) { + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints. 
+ std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1)); + EXPECT_FALSE(Result.second); + EXPECT_EQ(N, Result.first->Data); + } + + // Check the trie so we can confirm the structure is correct. Each subtrie + // should have 2 slots. The root's index=0 should have the content for + // 0x37 directly, and index=1 should be a linked-list of subtries, finally + // ending with content for (max-2) and (max-3). + // + // Note: This structure is not exhaustive (too expensive to update tests), + // but it does test that the dump format is somewhat readable and that the + // basic structure is correct. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie. + // We should have allocated a total of 64 tries (root included) for a 64-bit hash. + ASSERT_EQ(getNumTries(), 64u); + // Check the root trie. It has two slots and both are used. + ASSERT_EQ(getNumSlotUsed(getRoot()), 2u); + // Check last subtrie. + // Last allocated trie is the next node in the allocation chain. 
+ auto LastAlloctedSubTrie = getNextTrie(getRoot()); + ASSERT_EQ(getTriePrefixAsString(LastAlloctedSubTrie), + "fffffffffffffff[110]"); + ASSERT_EQ(getStartBit(LastAlloctedSubTrie), 63u); + ASSERT_EQ(getNumBits(LastAlloctedSubTrie), 1u); + ASSERT_EQ(getNumSlotUsed(LastAlloctedSubTrie), 2u); +} + +// Use subtrie size of 5 to avoid hitting 64 evenly, making the final subtrie +// small. +using SmallFinalNodeTrieTest = + SimpleTrieHashMapTest; +TEST_F(SmallFinalNodeTrieTest, TrieStructureSmallFinalSubtrie) { + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints. + std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1)); + EXPECT_FALSE(Result.second); + EXPECT_EQ(N, Result.first->Data); + } + + // Check the trie so we can confirm the structure is correct. The root + // should have 2^8=256 slots, most subtries should have 2^5=32 slots, and the + // deepest subtrie should have 2^1=2 slots (since (64-8)mod(5)=1). The + // root's index=0 should have the content for 0x37 directly, and index=255 + // should be a linked-list of subtries, finally ending with content for + // (max-2) and (max-3). 
+ // + // Note: This structure is not exhaustive (too expensive to update tests), + // but it does test that the dump format is somewhat readable and that the + // basic structure is correct. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie. + // 64 bit hash = 8 + 5 * 11 + 1, so 1 root, 11 5-bit subtries and 1 last-level + // subtrie, 13 total. + ASSERT_EQ(getNumTries(), 13u); + // Check the root trie. Only two of its slots are used. + ASSERT_EQ(getNumSlotUsed(getRoot()), 2u); + // Check last subtrie. + // Last allocated trie is the next node in the allocation chain. + auto LastAlloctedSubTrie = getNextTrie(getRoot()); + ASSERT_EQ(getTriePrefixAsString(LastAlloctedSubTrie), + "fffffffffffffff[110]"); + ASSERT_EQ(getStartBit(LastAlloctedSubTrie), 63u); + ASSERT_EQ(getNumBits(LastAlloctedSubTrie), 1u); + ASSERT_EQ(getNumSlotUsed(LastAlloctedSubTrie), 2u); +} + +TEST(TrieRawHashMapTest, TrieDestructionLoop) { + // Test destroying large Trie. Make sure there is no recursion that can + // overflow the stack. + using NumT = uint64_t; + struct NumWithDestructorT { + NumT Num; + operator NumT() const { return Num; } + ~NumWithDestructorT() {} + }; + + using HashT = std::array; + using TrieT = ThreadSafeTrieRawHashMap; + using TrieWithDestructorT = + ThreadSafeTrieRawHashMap; + + // Use the number itself in big-endian order as the hash. + auto hash = [](NumT Num) { + NumT HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashT Hash; + memcpy(&Hash[0], &HashN, sizeof(HashT)); + return Hash; + }; + + // Use optionals to control when destructors are called. + std::optional Trie; + std::optional TrieWithDestructor; + + // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate. 
+ Trie.emplace(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + TrieWithDestructor.emplace(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + + // Fill them up. Pick a MaxN high enough that a recursive destruction would + // overflow the stack in debug builds. + static constexpr uint64_t MaxN = 100000; + for (uint64_t N = 0; N != MaxN; ++N) { + HashT Hash = hash(N); + Trie->insert(TrieT::pointer(), TrieT::value_type(Hash, N)); + TrieWithDestructor->insert( + TrieWithDestructorT::pointer(), + TrieWithDestructorT::value_type(Hash, NumWithDestructorT{N})); + } + + // Destroy tries. If destruction is recursive and MaxN is high enough, these + // will both fail. + Trie.reset(); + TrieWithDestructor.reset(); +} + +namespace { +static constexpr unsigned HashSize = 20; +using HashType = std::array; + +class MockTrieStringSet + : ThreadSafeTrieRawHashMap { +public: + using TrieType = typename MockTrieStringSet::ThreadSafeTrieRawHashMap; + using LazyValueConstructor = typename MockTrieStringSet:: + ThreadSafeTrieRawHashMap::LazyValueConstructor; + + class pointer : public TrieType::const_pointer { + using BaseType = typename TrieType::const_pointer; + + public: + const std::string &operator*() const { + return TrieType::const_pointer::operator*().Data; + } + const std::string *operator->() const { return &operator*(); } + + pointer() = default; + pointer(pointer &&) = default; + pointer(const pointer &) = default; + pointer &operator=(pointer &&) = default; + pointer &operator=(const pointer &) = default; + + private: + pointer(BaseType Result) : BaseType(Result) {} + friend class MockTrieStringSet; + }; + + MockTrieStringSet(std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt) + : TrieType(NumRootBits, NumSubtrieBits) {} + + static HashType hash(const std::string &V) { + // Mock hash function. Create unique hash for the test case inputs below. 
+ const unsigned MaxSize = 4; + assert(V.size() <= MaxSize && "Invalid Input"); + std::string Input = V; + Input.resize(MaxSize); + HashType Hash; + // Repeat the padded string to hash length so SubTries will be created with + // different test configuration. + for (unsigned I = 0; I < HashSize; I += MaxSize) + llvm::copy(Input, Hash.begin() + I); + return Hash; + } + + pointer find(const std::string &Value) const { + return pointer(TrieType::find(hash(Value))); + } + std::pair insert(pointer Hint, std::string &&Value) { + auto Result = TrieType::insertLazy( + typename pointer::BaseType(Hint), hash(Value), + [&](LazyValueConstructor C) { C(std::move(Value)); }); + return {pointer(Result.first), Result.second}; + } + std::pair insert(pointer Hint, const std::string &Value) { + auto Result = + TrieType::insertLazy(typename pointer::BaseType(Hint), hash(Value), + [&](LazyValueConstructor C) { C(Value); }); + return {pointer(Result.first), Result.second}; + } + std::pair insert(std::string &&Value) { + return insert(pointer(), std::move(Value)); // Keep it an rvalue so the moving overload is used. + } + std::pair insert(const std::string &Value) { + return insert(pointer(), Value); + } +}; +} // end anonymous namespace + +TEST(TrieRawHashMapTest, Strings) { + for (unsigned RootBits : {2, 3, 6, 10}) { + for (unsigned SubtrieBits : {2, 3, 4}) { + MockTrieStringSet Strings(RootBits, SubtrieBits); + auto A1 = Strings.insert("A"); + EXPECT_EQ(&*A1.first, &*Strings.insert("A").first); + std::string A2 = *A1.first; + EXPECT_EQ(&*A1.first, &*Strings.insert(A2).first); + + auto B1 = Strings.insert("B"); + EXPECT_EQ(&*B1.first, &*Strings.insert(*B1.first).first); + std::string B2 = *B1.first; + EXPECT_EQ(&*B1.first, &*Strings.insert(B2).first); + + for (int I = 0, E = 1000; I != E; ++I) { + MockTrieStringSet::pointer Lookup; + std::string S = Twine(I).str(); + if (I & 1) + Lookup = Strings.find(S); + auto S1 = Strings.insert(Lookup, S); + EXPECT_EQ(&*S1.first, &*Strings.insert(*S1.first).first); + std::string S2 = *S1.first; + 
EXPECT_EQ(&*S1.first, &*Strings.insert(S2).first); + } + for (int I = 0, E = 1000; I != E; ++I) { + std::string S = Twine(I).str(); + MockTrieStringSet::pointer Lookup = Strings.find(S); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(S, *Lookup); + } + } + } +} + +} // namespace