diff --git a/llvm/include/llvm/ADT/HashMappedTrie.h b/llvm/include/llvm/ADT/HashMappedTrie.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/ADT/HashMappedTrie.h
@@ -0,0 +1,416 @@
+//===- HashMappedTrie.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_HASHMAPPEDTRIE_H
+#define LLVM_ADT_HASHMAPPEDTRIE_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <atomic>
+#include <optional>
+
+namespace llvm {
+
+class raw_ostream;
+
+/// HashMappedTrie - a lock-free thread-safe trie that can be used to
+/// store/index data based on a hash value. It can be customized to work with
+/// any hash algorithm or store any data.
+///
+/// Data structure:
+/// Data node stored in the Trie contains both hash and data:
+///   struct {
+///     HashT Hash;
+///     DataT Data;
+///   };
+///
+/// Data is stored/indexed via a prefix tree, where each node in the tree can be
+/// either the root, a sub-trie or a data node. Assuming a 4-bit hash and two
+/// data objects {0001, A} and {0100, B}, it can be stored in a trie
+/// (assuming Root has 2 bits, SubTrie has 1 bit):
+///  +--------+
+///  |Root[00]| -> {0001, A}
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+///
+/// Inserting a new object {0010, C} will result in:
+///  +--------+    +----------+
+///  |Root[00]| -> |SubTrie[0]| -> {0001, A}
+///  |        |    |       [1]| -> {0010, C}
+///  |        |    +----------+
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+/// Note object A is sunk down to a sub-trie during the insertion. All the
+/// nodes are inserted through compare-exchange to ensure thread-safety and
+/// lock-freedom.
+///
+/// To find an object in the trie, walk the tree with prefix of the hash until
+/// the data node is found. Then the hash is compared with the hash stored in
+/// the data node to see if it is the same object.
+///
+/// Hash collisions are not allowed, so it is recommended to use the trie with
+/// a "strong" hashing algorithm. A well-distributed hash can also result in
+/// better performance and memory usage.
+///
+/// It currently does not support iteration and deletion.
+
+/// Base class for a lock-free thread-safe hash-mapped trie.
+class ThreadSafeHashMappedTrieBase {
+public:
+  static constexpr size_t TrieContentBaseSize = 4;
+  static constexpr size_t DefaultNumRootBits = 6;
+  static constexpr size_t DefaultNumSubtrieBits = 4;
+
+private:
+  template <class T> struct AllocValueType {
+    char Base[TrieContentBaseSize];
+    std::aligned_union_t<sizeof(T), T> Content;
+  };
+
+protected:
+  template <class T>
+  static constexpr size_t DefaultContentAllocSize = sizeof(AllocValueType<T>);
+
+  template <class T>
+  static constexpr size_t DefaultContentAllocAlign = alignof(AllocValueType<T>);
+
+  template <class T>
+  static constexpr size_t DefaultContentOffset =
+      offsetof(AllocValueType<T>, Content);
+
+public:
+  void operator delete(void *Ptr) { ::free(Ptr); }
+
+  LLVM_DUMP_METHOD void dump() const;
+  void print(raw_ostream &OS) const;
+
+protected:
+  /// Result of a lookup. Suitable for an insertion hint. Maybe could be
+  /// expanded into an iterator of sorts, but likely not useful (visiting
+  /// everything in the trie should probably be done some way other than
+  /// through an iterator pattern).
+  class PointerBase {
+  protected:
+    void *get() const { return I == -2u ? P : nullptr; }
+
+  public:
+    PointerBase() noexcept = default;
+    PointerBase(PointerBase &&) = default;
+    PointerBase(const PointerBase &) = default;
+    PointerBase &operator=(PointerBase &&) = default;
+    PointerBase &operator=(const PointerBase &) = default;
+
+  private:
+    friend class ThreadSafeHashMappedTrieBase;
+    explicit PointerBase(void *Content) : P(Content), I(-2u) {}
+    PointerBase(void *P, unsigned I, unsigned B) : P(P), I(I), B(B) {}
+
+    bool isHint() const { return I != -1u && I != -2u; }
+
+    void *P = nullptr;
+    unsigned I = -1u; // -1u: empty; -2u: P is content; else a slot index.
+    unsigned B = 0;   // For hints: start bit of the subtrie containing slot I.
+  };
+
+  PointerBase find(ArrayRef<uint8_t> Hash) const;
+
+  /// Insert and return the stored content.
+  PointerBase
+  insert(PointerBase Hint, ArrayRef<uint8_t> Hash,
+         function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+             Constructor);
+
+  ThreadSafeHashMappedTrieBase() = delete;
+
+  ThreadSafeHashMappedTrieBase(
+      size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+      std::optional<size_t> NumRootBits = std::nullopt,
+      std::optional<size_t> NumSubtrieBits = std::nullopt);
+
+  /// Destructor, which asserts if there's anything to do. Subclasses should
+  /// call \a destroyImpl().
+  ///
+  /// \pre \a destroyImpl() was already called.
+  ~ThreadSafeHashMappedTrieBase();
+  void destroyImpl(function_ref<void(void *ValueMem)> Destructor);
+
+  ThreadSafeHashMappedTrieBase(ThreadSafeHashMappedTrieBase &&RHS);
+
+  // Move assignment can be implemented in a thread-safe way if NumRootBits and
+  // NumSubtrieBits are stored inside the Root.
+  ThreadSafeHashMappedTrieBase &
+  operator=(ThreadSafeHashMappedTrieBase &&RHS) = delete;
+
+  // No copy.
+  ThreadSafeHashMappedTrieBase(const ThreadSafeHashMappedTrieBase &) = delete;
+  ThreadSafeHashMappedTrieBase &
+  operator=(const ThreadSafeHashMappedTrieBase &) = delete;
+
+  // Debug functions. Implementation details and not guaranteed to be
+  // thread-safe.
+  PointerBase getRoot() const;
+  unsigned getStartBit(PointerBase P) const;
+  unsigned getNumBits(PointerBase P) const;
+  unsigned getNumSlotUsed(PointerBase P) const;
+  unsigned getNumTries() const;
+  // Visit next trie in the allocation chain.
+  PointerBase getNextTrie(PointerBase P) const;
+
+private:
+  const unsigned short ContentAllocSize;
+  const unsigned short ContentAllocAlign;
+  const unsigned short ContentOffset;
+  unsigned short NumRootBits;
+  unsigned short NumSubtrieBits;
+  struct ImplType;
+  // ImplPtr is owned by ThreadSafeHashMappedTrieBase and needs to be freed in
+  // destroyImpl.
+  std::atomic<ImplType *> ImplPtr;
+  ImplType &getOrCreateImpl();
+};
+
+/// Lock-free thread-safe hash-mapped trie.
+template <class T, size_t NumHashBytes>
+class ThreadSafeHashMappedTrie : ThreadSafeHashMappedTrieBase {
+public:
+  using HashT = std::array<uint8_t, NumHashBytes>;
+
+  class LazyValueConstructor;
+  struct value_type {
+    const HashT Hash;
+    T Data;
+
+    value_type(value_type &&) = default;
+    value_type(const value_type &) = default;
+
+    value_type(ArrayRef<uint8_t> Hash, const T &Data)
+        : Hash(copyHash(Hash)), Data(Data) {}
+    value_type(ArrayRef<uint8_t> Hash, T &&Data)
+        : Hash(copyHash(Hash)), Data(std::move(Data)) {}
+
+  private:
+    friend class LazyValueConstructor;
+
+    struct EmplaceTag {};
+    template <class... ArgsT>
+    value_type(ArrayRef<uint8_t> Hash, EmplaceTag, ArgsT &&...Args)
+        : Hash(copyHash(Hash)), Data(std::forward<ArgsT>(Args)...) {}
+
+    static HashT copyHash(ArrayRef<uint8_t> HashRef) {
+      assert(HashRef.size() == sizeof(HashT) && "Invalid hash");
+      HashT Hash;
+      std::copy(HashRef.begin(), HashRef.end(), Hash.data());
+      return Hash;
+    }
+  };
+
+  using ThreadSafeHashMappedTrieBase::operator delete;
+  using HashType = HashT;
+
+  using ThreadSafeHashMappedTrieBase::dump;
+  using ThreadSafeHashMappedTrieBase::print;
+
+private:
+  template <class ValueT> class PointerImpl : PointerBase {
+    friend class ThreadSafeHashMappedTrie;
+
+    ValueT *get() const {
+      if (void *B = PointerBase::get())
+        return reinterpret_cast<ValueT *>(B);
+      return nullptr;
+    }
+
+  public:
+    ValueT &operator*() const {
+      assert(get());
+      return *get();
+    }
+    ValueT *operator->() const {
+      assert(get());
+      return get();
+    }
+    explicit operator bool() const { return get(); }
+
+    PointerImpl() = default;
+    PointerImpl(PointerImpl &&) = default;
+    PointerImpl(const PointerImpl &) = default;
+    PointerImpl &operator=(PointerImpl &&) = default;
+    PointerImpl &operator=(const PointerImpl &) = default;
+
+  protected:
+    PointerImpl(PointerBase Result) : PointerBase(Result) {}
+  };
+
+public:
+  class pointer;
+  class const_pointer;
+  class pointer : public PointerImpl<value_type> {
+    friend class ThreadSafeHashMappedTrie;
+    friend class const_pointer;
+
+  public:
+    pointer() = default;
+    pointer(pointer &&) = default;
+    pointer(const pointer &) = default;
+    pointer &operator=(pointer &&) = default;
+    pointer &operator=(const pointer &) = default;
+
+  private:
+    pointer(PointerBase Result) : pointer::PointerImpl(Result) {}
+  };
+
+  class const_pointer : public PointerImpl<const value_type> {
+    friend class ThreadSafeHashMappedTrie;
+
+  public:
+    const_pointer() = default;
+    const_pointer(const_pointer &&) = default;
+    const_pointer(const const_pointer &) = default;
+    const_pointer &operator=(const_pointer &&) = default;
+    const_pointer &operator=(const const_pointer &) = default;
+
+    const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {}
+
+  private:
+    const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {}
+  };
+
+  class LazyValueConstructor {
+  public:
+    value_type &operator()(T &&RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, std::move(RHS)));
+    }
+    value_type &operator()(const T &RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, RHS));
+    }
+    template <class... ArgsT> value_type &emplace(ArgsT &&...Args) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem)
+                        value_type(Hash, typename value_type::EmplaceTag{},
+                                   std::forward<ArgsT>(Args)...));
+    }
+
+    LazyValueConstructor(LazyValueConstructor &&RHS)
+        : Mem(RHS.Mem), Result(RHS.Result), Hash(RHS.Hash) {
+      RHS.Mem = nullptr; // Moved away, cannot call.
+    }
+    ~LazyValueConstructor() { assert(!Mem && "Constructor never called!"); }
+
+  private:
+    value_type &assign(value_type *V) {
+      Mem = nullptr;
+      Result = V;
+      return *V;
+    }
+    friend class ThreadSafeHashMappedTrie;
+    LazyValueConstructor() = delete;
+    LazyValueConstructor(void *Mem, value_type *&Result, ArrayRef<uint8_t> Hash)
+        : Mem(Mem), Result(Result), Hash(Hash) {
+      assert(Hash.size() == sizeof(HashT) && "Invalid hash");
+      assert(Mem && "Invalid memory for construction");
+    }
+    void *Mem;
+    value_type *&Result;
+    ArrayRef<uint8_t> Hash;
+  };
+
+  /// Insert with a hint. Default-constructed hint will work, but it's
+  /// recommended to start with a lookup to avoid overhead in object creation
+  /// if it already exists.
+  pointer insertLazy(const_pointer Hint, ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return pointer(ThreadSafeHashMappedTrieBase::insert(
+        Hint, Hash, [&](void *Mem, ArrayRef<uint8_t> Hash) {
+          value_type *Result = nullptr;
+          OnConstruct(LazyValueConstructor(Mem, Result, Hash));
+          return Result->Hash.data();
+        }));
+  }
+
+  pointer insertLazy(ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return insertLazy(const_pointer(), Hash, OnConstruct);
+  }
+
+  pointer insert(const_pointer Hint, value_type &&HashedData) {
+    return insertLazy(Hint, HashedData.Hash, [&](LazyValueConstructor C) {
+      C(std::move(HashedData.Data));
+    });
+  }
+
+  pointer insert(const_pointer Hint, const value_type &HashedData) {
+    return insertLazy(Hint, HashedData.Hash,
+                      [&](LazyValueConstructor C) { C(HashedData.Data); });
+  }
+
+  pointer find(ArrayRef<uint8_t> Hash) {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeHashMappedTrieBase::find(Hash);
+  }
+
+  const_pointer find(ArrayRef<uint8_t> Hash) const {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeHashMappedTrieBase::find(Hash);
+  }
+
+  ThreadSafeHashMappedTrie(std::optional<size_t> NumRootBits = std::nullopt,
+                           std::optional<size_t> NumSubtrieBits = std::nullopt)
+      : ThreadSafeHashMappedTrieBase(DefaultContentAllocSize<value_type>,
+                                     DefaultContentAllocAlign<value_type>,
+                                     DefaultContentOffset<value_type>,
+                                     NumRootBits, NumSubtrieBits) {}
+
+  ~ThreadSafeHashMappedTrie() {
+    if constexpr (std::is_trivially_destructible<value_type>::value)
+      this->destroyImpl(nullptr);
+    else
+      this->destroyImpl(
+          [](void *P) { static_cast<value_type *>(P)->~value_type(); });
+  }
+
+  // Move constructor okay.
+  ThreadSafeHashMappedTrie(ThreadSafeHashMappedTrie &&) = default;
+
+  // No move assignment or any copy.
+  ThreadSafeHashMappedTrie &operator=(ThreadSafeHashMappedTrie &&) = delete;
+  ThreadSafeHashMappedTrie(const ThreadSafeHashMappedTrie &) = delete;
+  ThreadSafeHashMappedTrie &
+  operator=(const ThreadSafeHashMappedTrie &) = delete;
+
+  // Debug functions. Implementation details and not guaranteed to be
+  // thread-safe.
+  const_pointer getRoot() const {
+    return ThreadSafeHashMappedTrieBase::getRoot();
+  }
+  unsigned getStartBit(const_pointer P) const {
+    return ThreadSafeHashMappedTrieBase::getStartBit(P);
+  }
+  unsigned getNumBits(const_pointer P) const {
+    return ThreadSafeHashMappedTrieBase::getNumBits(P);
+  }
+  unsigned getNumSlotUsed(const_pointer P) const {
+    return ThreadSafeHashMappedTrieBase::getNumSlotUsed(P);
+  }
+  unsigned getNumTries() const {
+    return ThreadSafeHashMappedTrieBase::getNumTries();
+  }
+  const_pointer getNextTrie(const_pointer P) const {
+    return ThreadSafeHashMappedTrieBase::getNextTrie(P);
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_ADT_HASHMAPPEDTRIE_H
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -171,6 +171,7 @@
   FormatVariadic.cpp
   GlobPattern.cpp
   GraphWriter.cpp
+  HashMappedTrie.cpp
   Hashing.cpp
   InitLLVM.cpp
   InstructionCost.cpp
diff --git a/llvm/lib/Support/HashMappedTrie.cpp b/llvm/lib/Support/HashMappedTrie.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Support/HashMappedTrie.cpp
@@ -0,0 +1,545 @@
+//===- HashMappedTrie.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/HashMappedTrie.h"
+#include "HashMappedTrieIndexGenerator.h"
+#include "llvm/ADT/LazyAtomicPointer.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+struct TrieNode {
+  const bool IsSubtrie = false;
+
+  TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}
+
+  static void *operator new(size_t Size) { return ::malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); }
+};
+
+struct TrieContent final : public TrieNode {
+  const uint8_t ContentOffset;
+  const uint8_t HashSize;
+  const uint8_t HashOffset;
+
+  void *getValuePointer() const {
+    auto Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
+    return const_cast<uint8_t *>(Content);
+  }
+
+  ArrayRef<uint8_t> getHash() const {
+    auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
+    return makeArrayRef(Begin, Begin + HashSize);
+  }
+
+  TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
+      : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
+        HashSize(HashSize), HashOffset(HashOffset) {}
+};
+static_assert(sizeof(TrieContent) ==
+                  ThreadSafeHashMappedTrieBase::TrieContentBaseSize,
+              "Check header assumption!");
+
+class TrieSubtrie final : public TrieNode {
+public:
+  TrieNode *get(size_t I) const { return Slots[I].load(); }
+
+  TrieSubtrie *
+  sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+       function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);
+
+  void printHash(raw_ostream &OS, ArrayRef<uint8_t> Bytes) const;
+  void print(raw_ostream &OS) const { print(OS, std::nullopt); }
+  void print(raw_ostream &OS, std::optional<std::string> Prefix) const;
+  void dump() const { print(dbgs()); }
+
+  static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);
+
+  explicit TrieSubtrie(size_t StartBit, size_t NumBits);
+
+private:
+  // FIXME: Use a bitset to speed up access:
+  //
+  //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
+  //
+  // This will avoid needing to visit sparsely filled slots in
+  // \a ThreadSafeHashMappedTrieBase::destroyImpl() when there's a non-trivial
+  // destructor.
+  //
+  // It would also greatly speed up iteration, if we add that some day, and
+  // allow get() to return one level sooner.
+  //
+  // This would be the algorithm for updating IsSet (after updating Slots):
+  //
+  //     std::atomic<uint64_t> &Bits = IsSet[I.High];
+  //     const uint64_t NewBit = 1ULL << I.Low;
+  //     uint64_t Old = 0;
+  //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
+  //       ;
+
+  // For debugging.
+  unsigned StartBit = 0;
+  unsigned NumBits = 0;
+  friend class llvm::ThreadSafeHashMappedTrieBase;
+
+public:
+  /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
+  std::atomic<TrieSubtrie *> Next;
+
+  /// The (co-allocated) slots of the subtrie.
+  MutableArrayRef<LazyAtomicPointer<TrieNode>> Slots;
+};
+} // end namespace
+
+namespace llvm {
+template <> struct isa_impl<TrieContent, TrieNode> {
+  static inline bool doit(const TrieNode &TN) { return !TN.IsSubtrie; }
+};
+template <> struct isa_impl<TrieSubtrie, TrieNode> {
+  static inline bool doit(const TrieNode &TN) { return TN.IsSubtrie; }
+};
+} // end namespace llvm
+
+static size_t getTrieTailSize(size_t StartBit, size_t NumBits) {
+  assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+  return sizeof(LazyAtomicPointer<TrieNode>) * (1u << NumBits);
+}
+
+std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
+                                                 size_t NumBits) {
+  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
+  void *Memory = ::malloc(Size);
+  TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
+  return std::unique_ptr<TrieSubtrie>(S);
+}
+
+TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
+    : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Next(nullptr),
+      Slots(reinterpret_cast<LazyAtomicPointer<TrieNode> *>(
+                reinterpret_cast<char *>(this) + sizeof(TrieSubtrie)),
+            (1u << NumBits)) {
+  for (auto *I = Slots.begin(), *E = Slots.end(); I != E; ++I)
+    new (I) LazyAtomicPointer<TrieNode>(nullptr);
+
+  static_assert(
+      std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
+      "Expected no work in destructor for TrieNode");
+}
+
+TrieSubtrie *TrieSubtrie::sink(
+    size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+    function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
+  assert(NumSubtrieBits > 0);
+  std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
+
+  assert(NewI < S->Slots.size());
+  S->Slots[NewI].store(&Content);
+
+  TrieNode *ExistingNode = &Content;
+  assert(I < Slots.size());
+  if (Slots[I].compare_exchange_strong(ExistingNode, S.get()))
+    return Saver(std::move(S));
+
+  // Another thread created a subtrie already. Return it and let "S" be
+  // destructed.
+  return cast<TrieSubtrie>(ExistingNode);
+}
+
+struct ThreadSafeHashMappedTrieBase::ImplType {
+  static ImplType *create(size_t StartBit, size_t NumBits) {
+    size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits);
+    void *Memory = ::malloc(Size);
+    return ::new (Memory) ImplType(StartBit, NumBits);
+  }
+
+  TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
+    assert(!S->Next && "Expected S to a freshly-constructed leaf");
+
+    TrieSubtrie *CurrentHead = nullptr;
+    // Add ownership of "S" to front of the list, so that Root -> S ->
+    // Root.Next. This works by repeatedly setting S->Next to a candidate value
+    // of Root.Next (initially nullptr), then setting Root.Next to S once the
+    // candidate matches reality.
+    while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
+      S->Next.exchange(CurrentHead);
+
+    // Ownership transferred to subtrie.
+    return S.release();
+  }
+
+  static void *operator new(size_t Size) { return ::malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); }
+
+  /// FIXME: This should take a function that allocates and constructs the
+  /// content lazily (taking the hash as a separate parameter), in case of
+  /// collision.
+  ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;
+  TrieSubtrie Root; // Must be last! Tail-allocated.
+
+private:
+  ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {}
+};
+
+ThreadSafeHashMappedTrieBase::ImplType &
+ThreadSafeHashMappedTrieBase::getOrCreateImpl() {
+  if (ImplType *Impl = ImplPtr.load())
+    return *Impl;
+
+  // Create a new ImplType and store it if another thread doesn't do so first.
+  // If another thread wins this one is destroyed locally.
+  std::unique_ptr<ImplType> Impl(ImplType::create(0, NumRootBits));
+  ImplType *ExistingImpl = nullptr;
+  if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
+    return *Impl.release();
+
+  return *ExistingImpl;
+}
+
+ThreadSafeHashMappedTrieBase::PointerBase
+ThreadSafeHashMappedTrieBase::find(ArrayRef<uint8_t> Hash) const {
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType *Impl = ImplPtr.load();
+  if (!Impl)
+    return PointerBase();
+
+  TrieSubtrie *S = &Impl->Root;
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  size_t Index = IndexGen.next();
+  for (;;) {
+    // Look up the node in the current slot; an empty slot is an insert hint.
+    TrieNode *Existing = S->get(Index);
+    if (!Existing)
+      return PointerBase(S, Index, *IndexGen.StartBit);
+
+    // Check for an exact match.
+    if (auto *ExistingContent = dyn_cast<TrieContent>(Existing))
+      return ExistingContent->getHash() == Hash
+                 ? PointerBase(ExistingContent->getValuePointer())
+                 : PointerBase(S, Index, *IndexGen.StartBit);
+
+    Index = IndexGen.next();
+    S = cast<TrieSubtrie>(Existing);
+  }
+}
+
+ThreadSafeHashMappedTrieBase::PointerBase ThreadSafeHashMappedTrieBase::insert(
+    PointerBase Hint, ArrayRef<uint8_t> Hash,
+    function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+        Constructor) {
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType &Impl = getOrCreateImpl();
+  TrieSubtrie *S = &Impl.Root;
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  size_t Index;
+  if (Hint.isHint()) {
+    S = static_cast<TrieSubtrie *>(Hint.P);
+    Index = IndexGen.hint(Hint.I, Hint.B);
+  } else {
+    Index = IndexGen.next();
+  }
+
+  for (;;) {
+    // Load the node from the slot, allocating and calling the constructor if
+    // the slot is empty.
+    bool Generated = false;
+    TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() {
+      Generated = true;
+
+      // Construct the value itself at the tail.
+      uint8_t *Memory = reinterpret_cast<uint8_t *>(
+          Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
+      const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash);
+
+      // Construct the TrieContent header, passing in the offset to the hash.
+      TrieContent *Content = ::new (Memory)
+          TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
+      assert(Hash == Content->getHash() && "Hash not properly initialized");
+      return Content;
+    });
+    // If we just generated it, return it!
+    if (Generated)
+      return PointerBase(cast<TrieContent>(Existing).getValuePointer());
+
+    if (isa<TrieSubtrie>(Existing)) {
+      S = &cast<TrieSubtrie>(Existing);
+      Index = IndexGen.next();
+      continue;
+    }
+
+    // Return the existing content if it's an exact match!
+    auto &ExistingContent = cast<TrieContent>(Existing);
+    if (ExistingContent.getHash() == Hash)
+      return PointerBase(ExistingContent.getValuePointer());
+
+    // Sink the existing content as long as the indexes match.
+    for (;;) {
+      size_t NextIndex = IndexGen.next();
+      size_t NewIndexForExistingContent =
+          IndexGen.getCollidingBits(ExistingContent.getHash());
+      S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
+                  NewIndexForExistingContent,
+                  [&Impl](std::unique_ptr<TrieSubtrie> S) {
+                    return Impl.save(std::move(S));
+                  });
+      Index = NextIndex;
+
+      // Found the difference.
+      if (NextIndex != NewIndexForExistingContent)
+        break;
+    }
+  }
+}
+
+static void printHexDigit(raw_ostream &OS, uint8_t Digit) {
+  if (Digit < 10)
+    OS << char(Digit + '0');
+  else
+    OS << char(Digit - 10 + 'a');
+}
+
+static void printHexDigits(raw_ostream &OS, ArrayRef<uint8_t> Bytes,
+                           size_t StartBit, size_t NumBits) {
+  assert(StartBit % 4 == 0);
+  assert(NumBits % 4 == 0);
+  for (size_t I = StartBit, E = StartBit + NumBits; I != E; I += 4) {
+    uint8_t HexPair = Bytes[I / 8];
+    uint8_t HexDigit = I % 8 == 0 ? HexPair >> 4 : HexPair & 0xf;
+    printHexDigit(OS, HexDigit);
+  }
+}
+
+static void printBits(raw_ostream &OS, ArrayRef<uint8_t> Bytes, size_t StartBit,
+                      size_t NumBits) {
+  assert(StartBit + NumBits <= Bytes.size() * 8u);
+  for (size_t I = StartBit, E = StartBit + NumBits; I != E; ++I) {
+    uint8_t Byte = Bytes[I / 8];
+    size_t ByteOffset = I % 8;
+    if (size_t ByteShift = 8 - ByteOffset - 1)
+      Byte >>= ByteShift;
+    OS << (Byte & 0x1 ? '1' : '0');
+  }
+}
+
+void TrieSubtrie::printHash(raw_ostream &OS, ArrayRef<uint8_t> Bytes) const {
+  // afb[1c:00*01110*0]def
+  size_t EndBit = StartBit + NumBits;
+  size_t HashEndBit = Bytes.size() * 8u;
+
+  size_t FirstBinaryBit = StartBit & ~size_t(3);
+  printHexDigits(OS, Bytes, 0, FirstBinaryBit);
+
+  size_t LastBinaryBit = (EndBit + 3u) & ~size_t(3);
+  OS << "[";
+  printBits(OS, Bytes, FirstBinaryBit, LastBinaryBit - FirstBinaryBit);
+  OS << "]";
+
+  printHexDigits(OS, Bytes, LastBinaryBit, HashEndBit - LastBinaryBit);
+}
+
+static void appendIndexBits(std::string &Prefix, size_t Index,
+                            size_t NumSlots) {
+  std::string Bits;
+  for (size_t NumBits = 1u; NumBits < NumSlots; NumBits <<= 1) {
+    Bits.push_back('0' + (Index & 0x1));
+    Index >>= 1;
+  }
+  for (char Ch : llvm::reverse(Bits))
+    Prefix += Ch;
+}
+
+static void printPrefix(raw_ostream &OS, StringRef Prefix) {
+  while (Prefix.size() >= 4) {
+    uint8_t Digit;
+    bool ErrorParsingBinary = Prefix.take_front(4).getAsInteger(2, Digit);
+    assert(!ErrorParsingBinary);
+    (void)ErrorParsingBinary;
+    printHexDigit(OS, Digit);
+    Prefix = Prefix.drop_front(4);
+  }
+  if (!Prefix.empty())
+    OS << "[" << Prefix << "]";
+}
+
+void TrieSubtrie::print(raw_ostream &OS,
+                        std::optional<std::string> Prefix) const {
+  if (!Prefix) {
+    OS << "root";
+    Prefix.emplace();
+  } else {
+    OS << "subtrie=";
+    printPrefix(OS, *Prefix);
+  }
+
+  OS << " num-slots=" << Slots.size() << "\n";
+  SmallVector<TrieSubtrie *> Subs;
+  SmallVector<std::string> Prefixes;
+  for (size_t I = 0, E = Slots.size(); I != E; ++I) {
+    TrieNode *N = get(I);
+    if (!N)
+      continue;
+    OS << "- index=" << I << " ";
+    if (auto *S = dyn_cast<TrieSubtrie>(N)) {
+      std::string SubtriePrefix = *Prefix;
+      appendIndexBits(SubtriePrefix, I, Slots.size());
+      OS << "subtrie=";
+      printPrefix(OS, SubtriePrefix);
+      OS << "\n";
+      Subs.push_back(S);
+      Prefixes.push_back(SubtriePrefix);
+      continue;
+    }
+    auto *Content = cast<TrieContent>(N);
+    OS << "content=";
+    printHash(OS, Content->getHash());
+    OS << "\n";
+  }
+  for (size_t I = 0, E = Subs.size(); I != E; ++I)
+    Subs[I]->print(OS, Prefixes[I]);
+}
+
+void ThreadSafeHashMappedTrieBase::print(raw_ostream &OS) const {
+  OS << "root-bits=" << NumRootBits << " subtrie-bits=" << NumSubtrieBits
+     << "\n";
+  if (ImplType *Impl = ImplPtr.load())
+    Impl->Root.print(OS);
+  else
+    OS << "[no-root]\n";
+}
+
+LLVM_DUMP_METHOD void ThreadSafeHashMappedTrieBase::dump() const {
+  print(dbgs());
+}
+
+ThreadSafeHashMappedTrieBase::ThreadSafeHashMappedTrieBase(
+    size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+    std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits)
+    : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign),
+      ContentOffset(ContentOffset),
+      NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
+      NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
+      ImplPtr(nullptr) {
+  assert((!NumRootBits || (*NumRootBits > 0 && *NumRootBits < 20)) &&
+         "Root should have between 2 and ~1M slots");
+  assert((!NumSubtrieBits || (*NumSubtrieBits > 0 && *NumSubtrieBits < 10)) &&
+         "Subtries should have between 2 and ~1K slots");
+}
+
+ThreadSafeHashMappedTrieBase::ThreadSafeHashMappedTrieBase(
+    ThreadSafeHashMappedTrieBase &&RHS)
+    : ContentAllocSize(RHS.ContentAllocSize),
+      ContentAllocAlign(RHS.ContentAllocAlign),
+      ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits),
+      NumSubtrieBits(RHS.NumSubtrieBits) {
+  // Steal the root from RHS.
+  ImplPtr = RHS.ImplPtr.exchange(nullptr);
+}
+
+ThreadSafeHashMappedTrieBase::~ThreadSafeHashMappedTrieBase() {
+  assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()");
+}
+
+void ThreadSafeHashMappedTrieBase::destroyImpl(
+    function_ref<void(void *ValueMem)> Destructor) {
+  std::unique_ptr<ImplType> Impl(ImplPtr.exchange(nullptr));
+  if (!Impl)
+    return;
+
+  // Destroy content nodes throughout trie. Avoid destroying any subtries since
+  // we need TrieNode::classof() to find the content nodes.
+  //
+  // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them to
+  // facilitate sparse iteration here.
+  if (Destructor)
+    for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load())
+      for (auto &Slot : Trie->Slots)
+        if (auto *Content = dyn_cast_or_null<TrieContent>(Slot.load()))
+          Destructor(Content->getValuePointer());
+
+  // Destroy the subtries. Incidentally, this destroys them in the reverse order
+  // of saving.
+  TrieSubtrie *Trie = Impl->Root.Next;
+  while (Trie) {
+    TrieSubtrie *Next = Trie->Next.exchange(nullptr);
+    delete Trie;
+    Trie = Next;
+  }
+}
+
+ThreadSafeHashMappedTrieBase::PointerBase
+ThreadSafeHashMappedTrieBase::getRoot() const {
+  ImplType *Impl = ImplPtr.load();
+  if (!Impl)
+    return PointerBase();
+  return PointerBase(&Impl->Root);
+}
+
+unsigned ThreadSafeHashMappedTrieBase::getStartBit(
+    ThreadSafeHashMappedTrieBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return 0;
+  if (auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P)))
+    return S->StartBit;
+  return 0;
+}
+
+unsigned ThreadSafeHashMappedTrieBase::getNumBits(
+    ThreadSafeHashMappedTrieBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return 0;
+  if (auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P)))
+    return S->NumBits;
+  return 0;
+}
+
+unsigned ThreadSafeHashMappedTrieBase::getNumSlotUsed(
+    ThreadSafeHashMappedTrieBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return 0;
+  auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  if (!S)
+    return 0;
+  unsigned Num = 0;
+  for (unsigned I = 0, E = S->Slots.size(); I < E; ++I)
+    if (S->Slots[I].load())
+      ++Num;
+  return Num;
+}
+
+unsigned ThreadSafeHashMappedTrieBase::getNumTries() const {
+  ImplType *Impl = ImplPtr.load();
+  if (!Impl)
+    return 0;
+  unsigned Num = 0;
+  for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load())
+    ++Num;
+  return Num;
+}
+
+ThreadSafeHashMappedTrieBase::PointerBase
+ThreadSafeHashMappedTrieBase::getNextTrie(
+    ThreadSafeHashMappedTrieBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return PointerBase();
+  auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  if (!S)
+    return PointerBase();
+  if (auto *E = S->Next.load())
+    return PointerBase(E);
+  return PointerBase();
+}
diff --git a/llvm/lib/Support/HashMappedTrieIndexGenerator.h b/llvm/lib/Support/HashMappedTrieIndexGenerator.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Support/HashMappedTrieIndexGenerator.h
@@ -0,0 +1,89 @@
+//===- HashMappedTrieIndexGenerator.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_SUPPORT_HASHMAPPEDTRIEINDEXGENERATOR_H
+#define LLVM_LIB_SUPPORT_HASHMAPPEDTRIEINDEXGENERATOR_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include <optional>
+
+namespace llvm {
+
+struct IndexGenerator {
+  size_t NumRootBits;
+  size_t NumSubtrieBits;
+  ArrayRef<uint8_t> Bytes;
+  std::optional<size_t> StartBit = std::nullopt;
+
+  size_t getNumBits() const {
+    assert(StartBit);
+    size_t TotalNumBits = Bytes.size() * 8;
+    assert(*StartBit <= TotalNumBits);
+    return std::min(*StartBit ? NumSubtrieBits : NumRootBits,
+                    TotalNumBits - *StartBit);
+  }
+  size_t next() {
+    size_t Index;
+    if (!StartBit) {
+      StartBit = 0;
+      Index = getIndex(Bytes, *StartBit, NumRootBits);
+    } else {
+      *StartBit += *StartBit ? NumSubtrieBits : NumRootBits;
+      assert((*StartBit - NumRootBits) % NumSubtrieBits == 0);
+      Index = getIndex(Bytes, *StartBit, NumSubtrieBits);
+    }
+    return Index;
+  }
+
+  size_t hint(unsigned Index, unsigned Bit) {
+    assert(Index < (size_t(1) << (Bit ? NumSubtrieBits : NumRootBits)));
+    assert(Bit < Bytes.size() * 8);
+    assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0);
+    StartBit = Bit;
+    return Index;
+  }
+
+  size_t getCollidingBits(ArrayRef<uint8_t> CollidingBits) const {
+    assert(StartBit);
+    return getIndex(CollidingBits, *StartBit, NumSubtrieBits);
+  }
+
+  static size_t getIndex(ArrayRef<uint8_t> Bytes, size_t StartBit,
+                         size_t NumBits) {
+    assert(StartBit < Bytes.size() * 8);
+
+    Bytes = Bytes.drop_front(StartBit / 8u);
+    StartBit %= 8u;
+    size_t Index = 0;
+    for (uint8_t Byte : Bytes) {
+      size_t ByteStart = 0, ByteEnd = 8;
+      if (StartBit) {
+        ByteStart = StartBit;
+        Byte &= (1u << (8 - StartBit)) - 1u;
+        StartBit = 0;
+      }
+      size_t CurrentNumBits = ByteEnd - ByteStart;
+      if (CurrentNumBits > NumBits) {
+        Byte >>= CurrentNumBits - NumBits;
+        CurrentNumBits = NumBits;
+      }
+      Index <<= CurrentNumBits;
+      Index |= Byte & ((1u << CurrentNumBits) - 1u);
+
+      assert(NumBits >= CurrentNumBits);
+      NumBits -= CurrentNumBits;
+      if (!NumBits)
+        break;
+    }
+    return Index;
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_SUPPORT_HASHMAPPEDTRIEINDEXGENERATOR_H
diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt
--- a/llvm/unittests/ADT/CMakeLists.txt
+++ b/llvm/unittests/ADT/CMakeLists.txt
@@ -31,6 +31,7 @@
   FoldingSet.cpp
   FunctionExtrasTest.cpp
   FunctionRefTest.cpp
+  HashMappedTrieTest.cpp
   HashingTest.cpp
   IListBaseTest.cpp
   IListIteratorTest.cpp
diff --git a/llvm/unittests/ADT/HashMappedTrieTest.cpp b/llvm/unittests/ADT/HashMappedTrieTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/ADT/HashMappedTrieTest.cpp @@ -0,0 +1,395 @@ +//===- HashMappedTrieTest.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/HashMappedTrie.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/SHA1.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { +template class HashMappedTrieInspector { +public: + HashMappedTrieInspector(const ThreadSafeHashMappedTrie &Trie) + : Trie(Trie) {} + + unsigned getRootNodeUsage() const { + return Trie.getNumSlotUsed(Trie.getRoot()); + } + + unsigned getNumTries() const { + return Trie.getNumTries(); + } + + // The last allocated trie is the next in the allocation chain. + unsigned getLastAllocatedTrieStartBit() const { + return Trie.getStartBit(Trie.getNextTrie(Trie.getRoot())); + } + + unsigned getLastAllocatedTrieNumBits() const { + return Trie.getNumBits(Trie.getNextTrie(Trie.getRoot())); + } + + unsigned getLastAllocatedTrieUsage() const { + return Trie.getNumSlotUsed(Trie.getNextTrie(Trie.getRoot())); + } + +private: + const ThreadSafeHashMappedTrie &Trie; +}; + +template +static HashMappedTrieInspector +makeTrieInspector(const ThreadSafeHashMappedTrie &Trie) { + return HashMappedTrieInspector(Trie); +} + +TEST(HashMappedTrieTest, TrieInspector) { + using NumType = uint64_t; + using HashType = std::array; + using TrieType = ThreadSafeHashMappedTrie; + NumType Numbers[] = { + 0x0, + std::numeric_limits::max(), + 0x1, + 0x2, + 0x3, + std::numeric_limits::max() - 1u, + }; + + unsigned ExpectedTries[] = { + 1, // Allocate Root. + 1, // Both on the root. + 64, // 0 and 1 sink all the way down. + 64, // no new allocation needed.
+ 65, // need a new node between 2 and 3. + 65 + 63, // 63 new allocations to sink two big numbers all the way. + }; + + // Use the number itself as hash to test the pathological case. + auto hash = [](NumType Num) { + NumType HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashType Hash; + memcpy(&Hash[0], &HashN, sizeof(HashType)); + return Hash; + }; + + // Use root and subtrie sizes of 1 so this gets sunk quite deep. + TrieType Trie(1, 1); + auto Inspector = makeTrieInspector(Trie); + + for (size_t I = 0; I < std::size(Numbers); ++I) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(Numbers[I])); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(Numbers[I]), Numbers[I])); + + EXPECT_EQ(Inspector.getNumTries(), ExpectedTries[I]); + } +} + +TEST(HashMappedTrieTest, TrieStructure) { + using NumType = uint64_t; + using HashType = std::array; + using TrieType = ThreadSafeHashMappedTrie; + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints. + std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + // Use the number itself as hash to test the pathological case. + auto hash = [](NumType Num) { + NumType HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashType Hash; + memcpy(&Hash[0], &HashN, sizeof(HashType)); + return Hash; + }; + + // Use root and subtrie sizes of 1 so this gets sunk quite deep. + TrieType Trie(1, 1); + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries.
+ TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + EXPECT_EQ(N, + Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1))->Data); + } + + // Check the trie so we can confirm the structure is correct. Each subtrie + // should have 2 slots. The root's index=0 should have the content for + // 0x37 directly, and index=1 should be a linked-list of subtries, finally + // ending with content for (max-2) and (max-3). + // + // Note: These checks are not exhaustive (too expensive to update tests), + // but they do confirm the number of tries, the root's slot usage, and the + // shape of the deepest subtrie. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie. + auto Inspector = makeTrieInspector(Trie); + // We should allocate 64 tries: the root and one subtrie per remaining bit. + ASSERT_EQ(Inspector.getNumTries(), 64u); + // Check the root trie. Two slots and both are used. + ASSERT_EQ(Inspector.getRootNodeUsage(), 2u); + // Check last subtrie. + ASSERT_EQ(Inspector.getLastAllocatedTrieStartBit(), 63u); + ASSERT_EQ(Inspector.getLastAllocatedTrieNumBits(), 1u); + ASSERT_EQ(Inspector.getLastAllocatedTrieUsage(), 2u); +} + +TEST(HashMappedTrieTest, TrieStructureSmallFinalSubtrie) { + using NumType = uint64_t; + using HashType = std::array; + using TrieType = ThreadSafeHashMappedTrie; + NumType Numbers[] = { + // Three numbers that will nest deeply to test (1) sinking subtries and + // (2) deep, non-trivial hints.
+ std::numeric_limits::max(), + std::numeric_limits::max() - 2u, + std::numeric_limits::max() - 3u, + // One number to stay at the top-level. + 0x37, + }; + + // Use the number itself as hash to test the pathological case. + auto hash = [](NumType Num) { + NumType HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashType Hash; + memcpy(&Hash[0], &HashN, sizeof(HashType)); + return Hash; + }; + + // Use root size 8 and subtrie size 5 so the remaining 56 bits don't divide + // evenly, making the final subtrie small. + TrieType Trie(8, 5); + for (NumType N : Numbers) { + // Lookup first to exercise hint code for deep tries. + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_FALSE(Lookup); + + Trie.insert(Lookup, TrieType::value_type(hash(N), N)); + } + for (NumType N : Numbers) { + TrieType::pointer Lookup = Trie.find(hash(N)); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(hash(N), Lookup->Hash); + EXPECT_EQ(N, Lookup->Data); + + // Confirm a subsequent insertion fails to overwrite by trying to insert a + // bad value. + EXPECT_EQ(N, + Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1))->Data); + } + + // Check the trie so we can confirm the structure is correct. The root + // should have 2^8=256 slots, most subtries should have 2^5=32 slots, and the + // deepest subtrie should have 2^1=2 slots (since (64-8)mod(5)=1). + // The root's index=0 should have the content for + // 0x37 directly, and index=255 should be a linked-list of subtries, finally + // ending with content for (max-2) and (max-3). + // + // Note: These checks are not exhaustive (too expensive to update tests), + // but they do confirm the number of tries, the root's slot usage, and the + // shape of the deepest subtrie. + // + // Note: This test requires that the trie reads bytes starting from index 0 + // of the array of uint8_t, and then reads each byte's bits from high to low. + + // Check the Trie.
+ auto Inspector = makeTrieInspector(Trie); + // 64 bit hash = 8 + 5 * 11 + 1, so 1 root, 11 5-bit subtries and 1 last + // 1-bit subtrie, 13 total. + ASSERT_EQ(Inspector.getNumTries(), 13u); + // Check the root trie. Two slots are used. + ASSERT_EQ(Inspector.getRootNodeUsage(), 2u); + // Check last subtrie. The last allocated subtrie is the 1 bit subtrie. + ASSERT_EQ(Inspector.getLastAllocatedTrieStartBit(), 63u); + ASSERT_EQ(Inspector.getLastAllocatedTrieNumBits(), 1u); + ASSERT_EQ(Inspector.getLastAllocatedTrieUsage(), 2u); +} + +TEST(HashMappedTrieTest, TrieDestructionLoop) { + // Test destroying large Trie. Make sure there is no recursion that can + // overflow the stack. + using NumT = uint64_t; + struct NumWithDestructorT { + NumT Num; + operator NumT() const { return Num; } + ~NumWithDestructorT() {} + }; + + using HashT = std::array; + using TrieT = ThreadSafeHashMappedTrie; + using TrieWithDestructorT = + ThreadSafeHashMappedTrie; + + // Use the number itself in big-endian order as the hash. + auto hash = [](NumT Num) { + NumT HashN = llvm::support::endian::byte_swap(Num, llvm::support::big); + HashT Hash; + memcpy(&Hash[0], &HashN, sizeof(HashT)); + return Hash; + }; + + // Use optionals to control when destructors are called. + std::optional Trie; + std::optional TrieWithDestructor; + + // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate. + Trie.emplace(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + TrieWithDestructor.emplace(/*NumRootBits=*/1, /*NumSubtrieBits=*/1); + + // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug + // builds. + static constexpr uint64_t MaxN = 100000; + for (uint64_t N = 0; N != MaxN; ++N) { + HashT Hash = hash(N); + Trie->insert(TrieT::pointer(), TrieT::value_type(Hash, N)); + TrieWithDestructor->insert( + TrieWithDestructorT::pointer(), + TrieWithDestructorT::value_type(Hash, NumWithDestructorT{N})); + } + + // Destroy tries.
If destruction is recursive and MaxN is high enough, these + // will both fail. + Trie.reset(); + TrieWithDestructor.reset(); +} + +namespace { +static constexpr unsigned HashSize = 20; +using HashType = std::array; + +class MockTrieStringSet + : ThreadSafeHashMappedTrie { +public: + using TrieType = typename MockTrieStringSet::ThreadSafeHashMappedTrie; + using LazyValueConstructor = typename MockTrieStringSet:: + ThreadSafeHashMappedTrie::LazyValueConstructor; + + class pointer : public TrieType::const_pointer { + using BaseType = typename TrieType::const_pointer; + + public: + const std::string &operator*() const { + return TrieType::const_pointer::operator*().Data; + } + const std::string *operator->() const { return &operator*(); } + + pointer() = default; + pointer(pointer &&) = default; + pointer(const pointer &) = default; + pointer &operator=(pointer &&) = default; + pointer &operator=(const pointer &) = default; + + private: + pointer(BaseType Result) : BaseType(Result) {} + friend class MockTrieStringSet; + }; + + MockTrieStringSet(std::optional NumRootBits = std::nullopt, + std::optional NumSubtrieBits = std::nullopt) + : TrieType(NumRootBits, NumSubtrieBits) {} + + static HashType hash(const std::string &V) { + // Mock hash function. Create unique hash for the test case inputs below. + const unsigned MaxSize = 4; + assert(V.size() <= MaxSize && "Invalid Input"); + std::string Input = V; + Input.resize(MaxSize); + HashType Hash; + // Repeat the padded string to hash length so SubTries will be created with + // different test configurations.
+ for (unsigned I = 0; I < HashSize; I += MaxSize) + llvm::copy(Input, Hash.begin() + I); + return Hash; + } + + pointer find(const std::string &Value) const { + return pointer(TrieType::find(hash(Value))); + } + pointer insert(pointer Hint, std::string &&Value) { + return pointer(TrieType::insertLazy( + typename pointer::BaseType(Hint), hash(Value), + [&](LazyValueConstructor C) { C(std::move(Value)); })); + } + pointer insert(pointer Hint, const std::string &Value) { + return pointer( + TrieType::insertLazy(typename pointer::BaseType(Hint), hash(Value), + [&](LazyValueConstructor C) { C(Value); })); + } + // Forward as an rvalue so the string can be moved into the trie, not copied. + pointer insert(std::string &&Value) { + return insert(pointer(), std::move(Value)); + } + pointer insert(const std::string &Value) { return insert(pointer(), Value); } +}; +} // end anonymous namespace + +TEST(HashMappedTrieTest, Strings) { + for (unsigned RootBits : {2, 3, 6, 10}) { + for (unsigned SubtrieBits : {2, 3, 4}) { + MockTrieStringSet Strings(RootBits, SubtrieBits); + const std::string &A1 = *Strings.insert("A"); + EXPECT_EQ(&A1, &*Strings.insert("A")); + std::string A2 = A1; + EXPECT_EQ(&A1, &*Strings.insert(A2)); + + const std::string &B1 = *Strings.insert("B"); + EXPECT_EQ(&B1, &*Strings.insert(B1)); + std::string B2 = B1; + EXPECT_EQ(&B1, &*Strings.insert(B2)); + + for (int I = 0, E = 1000; I != E; ++I) { + MockTrieStringSet::pointer Lookup; + std::string S = Twine(I).str(); + if (I & 1) + Lookup = Strings.find(S); + const std::string &S1 = *Strings.insert(Lookup, S); + EXPECT_EQ(&S1, &*Strings.insert(S1)); + std::string S2 = S1; + EXPECT_EQ(&S1, &*Strings.insert(S2)); + } + for (int I = 0, E = 1000; I != E; ++I) { + std::string S = Twine(I).str(); + MockTrieStringSet::pointer Lookup = Strings.find(S); + EXPECT_TRUE(Lookup); + if (!Lookup) + continue; + EXPECT_EQ(S, *Lookup); + } + } + } +} + +} // namespace