diff --git a/llvm/include/llvm/ADT/ConcurrentHashtable.h b/llvm/include/llvm/ADT/ConcurrentHashtable.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/ADT/ConcurrentHashtable.h @@ -0,0 +1,930 @@ +//===- ConcurrentHashtable.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_CONCURRENTHASHTABLE_H +#define LLVM_ADT_CONCURRENTHASHTABLE_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Parallel.h" +#include "llvm/Support/WithColor.h" +#include +#include +#include +#include +#include +#include + +namespace llvm { + +/// ConcurrentHashTable - is a resizable concurrent hash table. +/// The default range of resizing without noticeable performance +/// degradation is up to x2^31 per bucket; the limit is derived from the +/// ChainPtrFreeBits constant. The hashtable allows only concurrent +/// insertions (though deletions could be easily added). + +/// Data structure: +/// +/// Inserted value is mapped to 64-bit hash value -> +/// +/// [------- 64-bit Hash value --------] +/// [ ChainIndex ][ Bucket Index ] +/// | | +/// points to the points to +/// bucket chains. the bucket. +/// +/// After initialization, all buckets consist of one chain (not initialized +/// at start). During insertions, buckets might be extended to contain more +/// chains. The number of chains kept by a bucket is called the bucket width. +/// Each bucket can be independently resized and rehashed (no need to lock the +/// whole table). Different buckets may have different widths.
+/// +/// HashTablesSet keeps chains of all buckets: +/// +/// HashTablesSet[ChainIdx][BucketIdx]: +/// +/// [ Bucket 0 ][ Bucket 1 ][Bucket ...][ Bucket M ] +/// [Chain 0] [Chain head][Chain head][Chain head][Chain head] +/// [Chain 1] [Chain head][Chain head][Chain head][Chain head] +/// [Chain 2] [Chain head][Chain head][Chain head][Chain head] +/// .......... variable size ............... +/// [Chain N] [Chain head][Chain head][Chain head][Chain head] +/// +/// Last bucket index == HashTableSize - 1 +/// Last chain index == BucketWidth - 1 +/// +/// Each chain is a linked list of items; each item keeps an array of entries: +/// +/// ChainHead->[ entry1 ] keeps entry data +/// [ entry2 ] +/// [ ... ] +/// [ last entry] +/// [ next chain item] -> points to the next chain item +/// +/// The pointer to the first bucket chain also keeps the bucket width (as +/// log2, in its free low bits), i.e. HashTablesSet[0][BucketIdx] keeps both: +/// the pointer to the first chain and the width of the bucket BucketIdx. +/// +/// ConcurrentHashTable: each entry keeps KeyDataTy in place or null. +/// +/// ConcurrentHashTableByPtr: each entry keeps a pair: hash value +/// and the pointer to the KeyDataTy or null. + +/// Entry of ConcurrentHashTableByPtr: the cached hash value of the key and a +/// pointer to the KeyDataTy object (both zero for an empty slot). +template struct HashedEntry { + uint64_t Hash; + KeyDataTy *Data = nullptr; +}; + +/// Tuning constants of ConcurrentHashTableBase. +template +class ConcurrentHashTableConstants { +public: + // On each extension the width of a bucket is multiplied by 2^GrowthRate: + // NewWidth = Width * 2^GrowthRate + static size_t constexpr GrowthRate = 1; + + // Size budget (in bytes) of a chain item: its entries plus the next pointer. + static size_t constexpr ChainItemSize = + 128 * sizeof(typename std::conditional::type); + + // Number of bucket mutexes per thread; the total is rounded to a power of + // two and capped by the hashtable size. + static size_t constexpr MutexesInitialSize = 256; + + // Number of low pointer bits (freed by the ChainItem alignment) that keep + // log2 of the bucket width. + static size_t constexpr ChainPtrFreeBits = 5; +}; + +/// AllocatorRefTy provides access to the allocator used for the table data. +/// It can refer either to: +/// +/// 1. thread local allocator.
+/// +/// static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator; +/// class AllocatorRef { +/// public: +/// static inline BumpPtrAllocator& getAllocatorRef() { +/// return DataAllocator; +/// } +/// }; +/// +/// ASSUMPTION: ThreadPool is used to keep threads alive and so to keep +/// data. +/// +/// 2. thread safe allocator. +/// +/// static ThreadSafeBumpPtrAllocator DataAllocator; +/// class AllocatorRef { +/// public: +/// static inline ThreadSafeBumpPtrAllocator& getAllocatorRef() { +/// return DataAllocator; +/// } +/// }; +/// + +/// ConcurrentHashTableBase - shared implementation of ConcurrentHashTable and +/// ConcurrentHashTableByPtr. KeepDataInsideTable selects whether entries keep +/// KeyDataTy in place or a HashedEntry (hash + pointer to KeyDataTy). +template > +class ConcurrentHashTableBase { +public: + static_assert((Constants::GrowthRate > 0), + "GrowthRate must be greater than 0."); + static_assert((Constants::MutexesInitialSize > 0), + "MutexesInitialSize must be greater than 0."); + + using MapEntryTy = typename std::conditional>::type; + using InsertionTy = + typename std::conditional::type; + + using InsertionResultTy = + typename std::conditional::type; + + /// \p EstimatedSize is the expected number of entries (used to size the + /// hashtable); \p ThreadsNum is the expected number of inserting threads + /// (used to size the set of bucket mutexes). + ConcurrentHashTableBase( + size_t EstimatedSize, + size_t ThreadsNum = parallel::strategy.compute_thread_count()) { + assert(ThreadsNum > 0); + + // Calculate hashtable size. + HashTableSize = EstimatedSize / EntriesPerChainItem; + HashTableSize = std::max(HashTableSize, (size_t)16); + HashTableSize = NextPowerOf2(HashTableSize); + NumberOfHashTables = 0; + + // Allocate first hashtable. + allocateNewHashTables(1); + + // Calculate number of mutexes. + BucketMutexesNum = NextPowerOf2(Constants::MutexesInitialSize * ThreadsNum); + BucketMutexesNum = std::min(BucketMutexesNum, HashTableSize); + + // Allocate mutexes. + BucketMutexes = + AllocatorRefTy::getAllocatorRef().template Allocate( + BucketMutexesNum); + for (size_t Idx = 0; Idx < BucketMutexesNum; Idx++) + new (&BucketMutexes[Idx]) std::mutex(); + + // Calculate masks.
+ size_t LeadingZerosNumber = countLeadingZeros(getHashMask()); + HashBitsNum = sizeof(hash_code) * UINT8_WIDTH - LeadingZerosNumber; + + // log2(bucket width) is kept in the ChainPtrFreeBits free low bits of the + // first chain pointer, so it is at most (ChainItemAlign - 1). It is also + // limited by LeadingZerosNumber so that the extended hash bits fit into + // the hash value. The shift is done on size_t: `1 << 31` overflows int. + MaxBucketWidth = size_t(1) + << std::min((ChainItemAlign - 1), LeadingZerosNumber); + + // Calculate mask for extended hash bits. + ExtHashMask = ((HashTableSize * MaxBucketWidth) - 1) & ~getHashMask(); + } + + /// Insert new value \p NewValue or return already existing entry. + /// + /// \returns entry and "true" if an entry is just inserted or + /// "false" if an entry already exists. + std::pair insert(const InsertionTy &NewValue) { + // Calculate bucket index. + hash_code Hash = getHashCode(NewValue); + size_t BucketIdx = Hash & getHashMask(); + size_t ExtHashBits = getExtHashBits(Hash); + + ValueInserter Handler(NewValue, + getBucketMutex(BucketIdx), Hash); + HashTableEntry &ZeroEntryRef = HashTablesSet[0][BucketIdx]; + + // Lock bucket. Handler unlocks it once the value is found or inserted; + // an iteration that ends with a bucket rehash keeps it locked and retries. + Handler.BucketMutex.lock(); + + while (true) { + // Get chain head. + ChainItem *ChainHead = ZeroEntryRef; + ExtBitsEntry ExtEntry = ExtBitsEntry::getFromOpaqueValue(ChainHead); + uint8_t BucketWidthValue = getBucketWidthValue(ExtEntry); + size_t BucketWidth = getBucketWidthFromValue(BucketWidthValue); + size_t ChainIdx = ExtHashBits & (BucketWidth - 1); + if (ChainIdx == 0) { + ChainHead = ExtEntry.getPointer(); + if (ChainHead == nullptr) { + ChainHead = allocateItem(); + ExtEntry.setPointer(ChainHead); + ZeroEntryRef = static_cast(ExtEntry.getOpaqueValue()); + } + } else { + HashTableEntry &EntryRef = HashTablesSet[ChainIdx][BucketIdx]; + ChainHead = EntryRef; + + if (ChainHead == nullptr) { + ChainHead = allocateItem(); + EntryRef = ChainHead; + } + } + + // For each chain entry... + if (forEachChainEntry>( + BucketIdx, BucketWidth, ChainHead, Handler)) + break; + } + + return Handler.Result; + } + + /// Print information about current state of hash table structures.
+ void printStatistic(raw_ostream &OS) { + OS << "\n--- HashTable statistic:\n"; + OS << "\nTable Size = " << HashTableSize; + OS << "\nNumber of tables = " << NumberOfHashTables; + OS << "\nEntries per chain item = " << (int)EntriesPerChainItem; + OS << "\nNumber of mutexes = " << (int)BucketMutexesNum; + + uint64_t NumberOfBuckets = 0; + uint64_t LongestChainLength = 0; + uint64_t NumberOfEntries = 0; + uint64_t NumberOfEntriesPlusEmpty = 0; + // Each hashtable is an array of HashTableSize chain-head pointers. + uint64_t OverallSize = sizeof(*this) + HashTableSize * NumberOfHashTables * + sizeof(HashTableEntry); + + class ChainLengthCounter { + public: + IterationStatus handleEntry(MapEntryTy &EntryData) { + if (!ConcurrentHashTableBase::isNull(EntryData)) + ChainLength++; + + return IterationStatus::Next; + } + uint64_t ChainLength = 0; + }; + DenseMap BucketWidthsMap; + + // For each bucket... + for (uint64_t CurBucketIdx = 0; CurBucketIdx < HashTableSize; + CurBucketIdx++) { + + NumberOfEntriesPlusEmpty += EntriesPerChainItem * NumberOfHashTables; + + if (isEmptyBucket(CurBucketIdx)) + continue; + + NumberOfBuckets++; + + size_t BucketWidth = + getBucketWidthFromValue(getBucketWidthValue(CurBucketIdx)); + BucketWidthsMap[BucketWidth] = BucketWidthsMap.lookup(BucketWidth) + 1; + + // For each chain... + for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + + ChainItem *ChainHead = + getChainHead(CurBucketIdx, CurChainIdx); + if (ChainHead == nullptr) + continue; + + ChainLengthCounter Handler; + + // For each chain entry... + [[maybe_unused]] bool ForEachSuccessful = + forEachChainEntry( + CurBucketIdx, BucketWidth, ChainHead, Handler); + assert(ForEachSuccessful); + + LongestChainLength = std::max(LongestChainLength, Handler.ChainLength); + NumberOfEntries += Handler.ChainLength; + + size_t ItemsInChain = + (Handler.ChainLength / EntriesPerChainItem) + + ((Handler.ChainLength % EntriesPerChainItem) > 0 ?
+ 1 : 0); + // A chain head stays allocated even if a rehash moved all of its + // values into other chains, so every existing chain has at least one + // item (and `ItemsInChain - 1` below must not underflow). + if (ItemsInChain == 0) + ItemsInChain = 1; + + NumberOfEntriesPlusEmpty += EntriesPerChainItem * (ItemsInChain - 1); + OverallSize += sizeof(ChainItem) * ItemsInChain; + // Account for the out-of-line KeyDataTy of every non-null entry. + if constexpr (std::is_pointer::value) + OverallSize += sizeof(KeyDataTy) * Handler.ChainLength; + } + } + + OS << "\nOverall number of entries = " << NumberOfEntries; + OS << "\nOverall number of buckets = " << NumberOfBuckets; + for (auto &Width : BucketWidthsMap) + OS << "\n Number of buckets with width " << Width.first << ": " + << Width.second; + OS << "\nLongest chain length = " << LongestChainLength; + + std::stringstream stream; + stream << std::fixed << std::setprecision(2) + << ((float)NumberOfEntries / (float)NumberOfEntriesPlusEmpty); + std::string str = stream.str(); + + OS << "\nLoad factor = " << str; + OS << "\nOverall allocated size = " << OverallSize; + } + +protected: + enum IterationStatus : uint8_t { Stop, Next }; + + // Set the alignment of ChainItem so that we have ChainPtrFreeBits bits + // available. + static size_t constexpr ChainItemAlign = 1 << Constants::ChainPtrFreeBits; + + // The size of chain item entry. + static size_t constexpr ChainItemEntrySize = + sizeof(typename std::conditional>::type); + + static_assert((Constants::ChainItemSize >= + (sizeof(uintptr_t) + ChainItemEntrySize)), + "ChainItemSize must be enough to keep atleast one entry."); + + // The number of entries per chain item.
+ static size_t constexpr EntriesPerChainItem = + (Constants::ChainItemSize - sizeof(uintptr_t)) / ChainItemEntrySize; + + struct alignas(ChainItemAlign) ChainItem { + ChainItem() = delete; + ChainItem(const ChainItem &) = delete; + ChainItem &operator=(const ChainItem &) = delete; + + MapEntryTy Entries[EntriesPerChainItem]; + ChainItem *Next; + }; + + static constexpr bool CreateChainHead = true; + static constexpr bool DoNotCreateChainHead = false; + + static constexpr bool CreateNewChainItems = true; + static constexpr bool DoNotCreateNewChainItems = false; + + using HashTableEntry = ChainItem *; + using HashTablePtr = HashTableEntry *; + + // Store log2 of the bucket width in the free low bits of the pointer. + using ExtBitsEntry = + PointerIntPair; + + bool isEmptyBucket(size_t BucketIdx) { + return ExtBitsEntry::getFromOpaqueValue(HashTablesSet[0][BucketIdx]) + .getPointer() == nullptr; + } + + uint8_t getBucketWidthValue(size_t BucketIdx) { + return ExtBitsEntry::getFromOpaqueValue(HashTablesSet[0][BucketIdx]) + .getInt(); + } + + uint8_t getBucketWidthValue(ChainItem *Head) { + return getBucketWidthValue(ExtBitsEntry::getFromOpaqueValue(Head)); + } + + uint8_t getBucketWidthValue(ExtBitsEntry ExtEntry) { + return ExtEntry.getInt(); + } + + // Convert a stored log2 value into the bucket width. The shift is done on + // size_t because the stored value can be 31 and `1 << 31` overflows int. + size_t getBucketWidthFromValue(uint8_t BucketWidthValue) { + return size_t(1) << BucketWidthValue; + } + + void setBucketWidth(size_t BucketIdx, size_t NewBucketWidth) { + assert((NewBucketWidth >= 1) && !(NewBucketWidth & (NewBucketWidth - 1))); + assert(getBucketWidthFromValue(getBucketWidthValue(BucketIdx)) < + NewBucketWidth); + + HashTableEntry &ExtEntryRef = HashTablesSet[0][BucketIdx]; + ExtBitsEntry ExtEntry = ExtBitsEntry::getFromOpaqueValue(ExtEntryRef); + ExtEntry.setInt(static_cast(countTrailingZeros(NewBucketWidth))); + + ExtEntryRef = static_cast(ExtEntry.getOpaqueValue()); + } + + size_t getExtHashBits(hash_code Hash) { + return (Hash & ExtHashMask) >> HashBitsNum; + } + + size_t getChainIdx(hash_code Hash, size_t
+ BucketWidth) { + assert(BucketWidth > 0); + + return getExtHashBits(Hash) & (BucketWidth - 1); + } + + // Return number of entries inside specified bucket. + size_t getNumberOfEntries(size_t BucketIdx, size_t BucketWidth) { + class EntriesCounter { + public: + IterationStatus handleEntry(MapEntryTy &EntryData) { + if (ConcurrentHashTableBase::isNull(EntryData)) + return IterationStatus::Next; + + EntriesNumber++; + return IterationStatus::Next; + } + size_t EntriesNumber = 0; + } Handler; + + for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + ChainItem *ChainHead = + getChainHead(BucketIdx, CurChainIdx); + if (ChainHead == nullptr) + continue; + + [[maybe_unused]] bool ForEachSuccessful = + forEachChainEntry( + BucketIdx, BucketWidth, ChainHead, Handler); + assert(ForEachSuccessful); + } + + return Handler.EntriesNumber; + } + + // Allocate a new zero-initialized chain item. + ChainItem *allocateItem() { + ChainItem *ResultItem = + AllocatorRefTy::getAllocatorRef().template Allocate(); + + assert(ResultItem != nullptr); + memset(ResultItem, 0x0, sizeof(ChainItem)); + + return ResultItem; + } + + // Return chain head. Create new head if CreateChainHead is true. + template + LLVM_ATTRIBUTE_ALWAYS_INLINE ChainItem * + getChainHead(size_t BucketIdx, hash_code Hash, size_t BucketWidth) { + size_t ChainIdx = getChainIdx(Hash, BucketWidth); + return getChainHead(BucketIdx, ChainIdx); + } + + // Return chain head. Create new head if CreateChainHead is true.
+ template + LLVM_ATTRIBUTE_ALWAYS_INLINE ChainItem * + getChainHead(size_t BucketIdx, size_t ChainIdx) { + ChainItem *ChainHead = HashTablesSet[ChainIdx][BucketIdx]; + + if (ChainIdx == 0) { + ExtBitsEntry ExtEntry = ExtBitsEntry::getFromOpaqueValue(ChainHead); + ChainHead = ExtEntry.getPointer(); + + if constexpr (CreateChainHead) { + if (ChainHead == nullptr) { + ChainHead = allocateItem(); + ExtEntry.setPointer(ChainHead); + HashTablesSet[ChainIdx][BucketIdx] = + static_cast(ExtEntry.getOpaqueValue()); + } + } + + return ChainHead; + } + + if constexpr (CreateChainHead) { + if (ChainHead == nullptr) { + ChainHead = allocateItem(); + HashTablesSet[ChainIdx][BucketIdx] = ChainHead; + } + } + + return ChainHead; + } + + // Return next chain item. Create new item if necessary. + template ::type * = nullptr> + LLVM_ATTRIBUTE_ALWAYS_INLINE ChainItem * + getNextChainItem(ChainItem *Item) { + ChainItem *NextItem = Item->Next; + + if (NextItem == nullptr) { + ChainItem *NewEntry = allocateItem(); + + Item->Next = NewEntry; + return NewEntry; + } + + return NextItem; + } + + // Return next chain item. + template ::type * = nullptr> + LLVM_ATTRIBUTE_ALWAYS_INLINE ChainItem * + getNextChainItem(ChainItem *Item) { + return Item->Next; + } + + // Extend HashTables and rehash specified bucket. + template ::type * = nullptr> + bool extendTable(size_t BucketIdx, size_t BucketWidth) { + assert(BucketWidth == + getBucketWidthFromValue(getBucketWidthValue(BucketIdx))); + if (BucketWidth >= MaxBucketWidth) + return false; + + ExtendAndRehashBucket(BucketIdx, BucketWidth); + return true; + } + + // Do nothing. + template ::type * = nullptr> + constexpr bool extendTable(size_t, size_t) { + return false; + } + + // Enumerate all entries for the specified bucket and bucket's chain. + // Returns false if the bucket was extended (rehashed) during the walk, so + // the caller must restart; returns true otherwise.
+ template + LLVM_ATTRIBUTE_ALWAYS_INLINE bool + forEachChainEntry(size_t BucketIdx, size_t BucketWidth, ChainItem *ChainHead, + EntryHandler &Handler) { + assert(ChainHead != nullptr); + + for (ChainItem *CurItem = ChainHead; CurItem != nullptr; + CurItem = getNextChainItem(CurItem)) { + MapEntryTy *Entries = CurItem->Entries; + for (size_t EntryIdx = 0; EntryIdx < EntriesPerChainItem;) { + switch (Handler.handleEntry(Entries[EntryIdx])) { + case Stop: + return true; + case Next: + EntryIdx++; + } + } + + if (extendTable(BucketIdx, BucketWidth)) + return false; + } + + return true; + } + + // Insert data into the chain item. BucketMutex is unlocked as soon as the + // value is found or inserted. + template class ValueInserter { + public: + ValueInserter(const InsertionTy &InsertionData, std::mutex &BucketMutex, + hash_code Hash) + : InsertionData(InsertionData), BucketMutex(BucketMutex), Hash(Hash) {} + + IterationStatus handleEntry(MapEntryTy &EntryData) { + MapEntryTy Value = EntryData; + if (ConcurrentHashTableBase::isNull(Value)) { + // Insert new value into the empty slot. + if constexpr (KeepDataInsideTable) { + Result.first = InsertionData; + EntryData = Result.first; + } else { + Result.first = + Info::create(InsertionData, AllocatorRefTy::getAllocatorRef()); + EntryData = {Hash, Result.first}; + } + // Check the stored entry: in the by-pointer mode Result.first is a + // KeyDataTy pointer, which isNull() does not accept. + assert(!ConcurrentHashTableBase::isNull(EntryData)); + BucketMutex.unlock(); + + Result.second = true; + return IterationStatus::Stop; + } + + if constexpr (KeepDataInsideTable) { + if (ConcurrentHashTableBase::isEqual(Value, InsertionData)) { + // Already existed entry matched with inserted data + // is found.
+ BucketMutex.unlock(); + Result.first = Value; + Result.second = false; + return IterationStatus::Stop; + } + } else { + if (ConcurrentHashTableBase::isEqual(Value, InsertionData, Hash)) { + // Already existed entry matched with inserted data + // is found. + BucketMutex.unlock(); + Result.first = Value.Data; + Result.second = false; + return IterationStatus::Stop; + } + } + + return IterationStatus::Next; + } + std::pair Result; + const InsertionTy &InsertionData; + std::mutex &BucketMutex; + hash_code Hash; + }; + + using CollectedValuesVector = SmallVector; + + // This function adds new chains to the bucket and rehashes existing + // values to place them into the new chains. + void ExtendAndRehashBucket(size_t BucketIdx, size_t BucketWidth) { + assert(BucketWidth < MaxBucketWidth); + assert(BucketWidth > 0); + + assert(Constants::GrowthRate < countLeadingZeros(BucketWidth)); + size_t NewBucketsWidth = + std::min(static_cast(BucketWidth << Constants::GrowthRate), + MaxBucketWidth); + + // Check whether we need to create new hashtables. + if (NewBucketsWidth > NumberOfHashTables) { + std::unique_lock Guard(HashTableMutex); + + if (NewBucketsWidth > NumberOfHashTables) + allocateNewHashTables(NewBucketsWidth); + } + + // Collect bucket values. + CollectedValuesVector Values; + collectValuesLocked(BucketIdx, BucketWidth, Values); + + assert(getNumberOfEntries(BucketIdx, BucketWidth) == 0); + + // Insert values into extended bucket. + insertValuesLocked(Values, BucketIdx, NewBucketsWidth); + + // Update bucket width. + setBucketWidth(BucketIdx, NewBucketsWidth); + } + + // Collect all data into the specified vector Values. + // Erase data from all bucket's chains. + void collectValuesLocked(size_t BucketIdx, size_t BucketWidth, + CollectedValuesVector &Values) { + class ValueCollector { + public: + ValueCollector(CollectedValuesVector &Values) : Values(Values) {} + + IterationStatus handleEntry(MapEntryTy &EntryData) { + MapEntryTy Value = EntryData; + + if (ConcurrentHashTableBase::isNull(Value)) + return IterationStatus::Stop; + + Values.emplace_back(Value); + setToNull(&EntryData); + + return IterationStatus::Next; + } + + CollectedValuesVector &Values; + } Handler(Values); + + // Collect bucket values.
+ for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + ChainItem *ChainHead = + getChainHead(BucketIdx, CurChainIdx); + if (ChainHead == nullptr) + continue; + + [[maybe_unused]] bool ForEachSuccessful = + forEachChainEntry( + BucketIdx, BucketWidth, ChainHead, Handler); + assert(ForEachSuccessful); + } + } + + // Insert data into the locked bucket. + void insertValuesLocked(CollectedValuesVector &Values, size_t BucketIdx, + size_t BucketWidth) { + for (MapEntryTy &NewValue : Values) { + assert(!(ConcurrentHashTableBase::isNull(NewValue))); + + hash_code Hash = ConcurrentHashTableBase::getHashCode(NewValue); + ChainItem *ChainHead = + getChainHead(BucketIdx, Hash, BucketWidth); + + class InsertEntry { + public: + InsertEntry(MapEntryTy NewValue) : NewValue(NewValue) {} + + IterationStatus handleEntry(MapEntryTy &EntryData) { + if (ConcurrentHashTableBase::isNull(EntryData)) { + EntryData = NewValue; + return IterationStatus::Stop; + } + + return IterationStatus::Next; + } + + MapEntryTy NewValue; + } Handler(NewValue); + + [[maybe_unused]] bool ForEachSuccessful = + forEachChainEntry( + BucketIdx, BucketWidth, ChainHead, Handler); + + assert(ForEachSuccessful); + } + } + + // Allocate new hashtables for specified NewBucketsWidth. + void allocateNewHashTables(size_t NewBucketsWidth) { + assert((NewBucketsWidth >= 1) && + !(NewBucketsWidth & (NewBucketsWidth - 1))); + assert(NumberOfHashTables < NewBucketsWidth); + + AllocatorTy &Allocator = AllocatorRefTy::getAllocatorRef(); + + HashTablePtr *NewHashTablesSet = + Allocator.template Allocate(NewBucketsWidth); + // Allocate and initialize new tables, copy old tables. + // TODO: tables initialization might be done in parallel.
+ for (size_t Idx = 0; Idx < NewBucketsWidth; Idx++) { + if (Idx < NumberOfHashTables) { + NewHashTablesSet[Idx] = HashTablesSet[Idx]; + continue; + } + + HashTablePtr NewTable = + Allocator.template Allocate(HashTableSize); + memset(NewTable, 0, sizeof(HashTableEntry) * HashTableSize); + NewHashTablesSet[Idx] = NewTable; + } + + HashTablesSet = NewHashTablesSet; + NumberOfHashTables = NewBucketsWidth; + } + + // Return mutex for the specified BucketIdx. + std::mutex &getBucketMutex(size_t BucketIdx) { + return BucketMutexes[BucketIdx & (BucketMutexesNum - 1)]; + } + + template >::value), + bool>::type = true> + static inline bool isNull(MapEntryTy Data) { + // Compare the fields instead of reading the entry as two uint64_t words, + // which can read past the object on targets with 32-bit pointers. + return Data.Hash == 0 && Data.Data == nullptr; + } + + // The overloads below read the entry bytes with memcpy, which (unlike a + // reinterpret_cast to an integer type) does not break strict aliasing. + template >::value && + sizeof(MapEntryTy) == sizeof(uint8_t)), + bool>::type = true> + static inline bool isNull(const MapEntryTy &Data) { + uint8_t Bits; + memcpy(&Bits, &Data, sizeof(Bits)); + return Bits == 0; + } + + template >::value && + sizeof(MapEntryTy) == sizeof(uint16_t)), + bool>::type = true> + static inline bool isNull(const MapEntryTy &Data) { + uint16_t Bits; + memcpy(&Bits, &Data, sizeof(Bits)); + return Bits == 0; + } + + template >::value && + sizeof(MapEntryTy) == sizeof(uint32_t)), + bool>::type = true> + static inline bool isNull(const MapEntryTy &Data) { + uint32_t Bits; + memcpy(&Bits, &Data, sizeof(Bits)); + return Bits == 0; + } + + template >::value && + sizeof(MapEntryTy) == sizeof(uint64_t)), + bool>::type = true> + static inline bool isNull(const MapEntryTy &Data) { + uint64_t Bits; + memcpy(&Bits, &Data, sizeof(Bits)); + return Bits == 0; + } + + // Reset the entry to the all-zero (empty) state. + static inline void setToNull(MapEntryTy *Data) { + memset(Data, 0, sizeof(MapEntryTy)); + } + + static inline bool isEqual(MapEntryTy LHS, const InsertionTy &RHS, + uint64_t Hash) { + return (LHS.Hash == Hash) && Info::isEqual(LHS.Data->key(), RHS); + } + + static inline bool isEqual(const MapEntryTy &LHS, const InsertionTy &RHS) { + return Info::isEqual(LHS.key(), RHS.key()); + } + + static inline hash_code
+ getHashCode(const KeyTy &Data) { + return Info::getHashValue(Data); + } + + static inline hash_code getHashCode(const KeyDataTy &Data) { + return Info::getHashValue(Data.key()); + } + + static inline hash_code getHashCode(const HashedEntry &Data) { + return Data.Hash; + } + + size_t getHashMask() { return HashTableSize - 1; } + + // Number of bits in hash mask. + uint64_t HashBitsNum = 0; + + // Hash mask for the extended hash bits. + uint64_t ExtHashMask = 0; + + // Array of mutexes. + std::mutex *BucketMutexes = nullptr; + + // Number of used mutexes. + size_t BucketMutexesNum = 2; + + // The maximal bucket width. + size_t MaxBucketWidth = 0; + + // The size of single hashtable. + size_t HashTableSize = 0; + + // The guard for HashTables array. + std::mutex HashTableMutex; + + // The current number of HashTables. + std::atomic NumberOfHashTables; + + // HashTables keeping buckets chains. + std::atomic HashTablesSet = nullptr; +}; + +/// ConcurrentHashTable: keeps a KeyDataTy (which is a key-value pair) in +/// place. The state when all data bits of KeyDataTy are zero is reserved to +/// mark an empty slot, so such a value must never be inserted. +/// + +template class ConcurrentHashTableInfo { +public: + /// \returns Hash value for the specified \p Key. + static inline hash_code getHashValue(const KeyTy &Key) { + return std::hash()(Key); + } + + /// \returns true if both \p LHS and \p RHS are equal. + static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) { + return LHS == RHS; + } +}; + +template , + typename Constants = ConcurrentHashTableConstants> +class ConcurrentHashTable + : public ConcurrentHashTableBase { +public: + ConcurrentHashTable(size_t InitialSize) + : ConcurrentHashTableBase(InitialSize) {} +}; + +/// ConcurrentHashTableByPtr: keeps a pointer to the KeyDataTy class +/// (which is a key-value pair). + +template +class ConcurrentHashTableInfoByPtr { +public: + /// \returns Hash value for the specified \p Key.
+ static inline hash_code getHashValue(const KeyTy &Key) { + return std::hash()(Key); + } + + /// \returns true if both \p LHS and \p RHS are equal. + static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) { + return LHS == RHS; + } + + /// \returns newly created object of KeyDataTy type. + static inline KeyDataTy *create(const KeyTy &Key, AllocatorTy &Allocator) { + return KeyDataTy::create(Key, Allocator); + } +}; + +template , + typename Constants = ConcurrentHashTableConstants> +class ConcurrentHashTableByPtr + : public ConcurrentHashTableBase { +public: + ConcurrentHashTableByPtr(size_t InitialSize) + : ConcurrentHashTableBase(InitialSize) {} +}; + +} // end namespace llvm + +#endif // LLVM_ADT_CONCURRENTHASHTABLE_H diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -17,6 +17,7 @@ BumpPtrListTest.cpp CoalescingBitVectorTest.cpp CombinationGeneratorTest.cpp + ConcurrentHashtableTest.cpp DAGDeltaAlgorithmTest.cpp DeltaAlgorithmTest.cpp DenseMapTest.cpp diff --git a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp @@ -0,0 +1,423 @@ +//===- ConcurrentHashtableTest.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ConcurrentHashtable.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Parallel.h" +#include "gtest/gtest.h" +#include +#include +#include +using namespace llvm; + +namespace { +class Int { +public: + Int() : Data(0x0) {} + uint32_t key() const { return Data & 0x7FFFFFFF; } + + friend bool operator==(const Int &LHS, const Int &RHS) { + return LHS.Data == RHS.Data; + } + + static Int create(uint32_t Data) { return Int(Data | 0x80000000); } + +protected: + Int(uint32_t Data) : Data(Data) {} + + uint32_t Data; +}; + +TEST(ConcurrentHashTableTest, AddIntEntries) { + static BumpPtrAllocator Allocator; + class AllocatorRef { + public: + static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; } + }; + ConcurrentHashTable HashTable(10); + + std::pair res1 = HashTable.insert(Int::create(1)); + // Check entry is inserted. + EXPECT_TRUE(res1.first.key() == 1); + EXPECT_TRUE(res1.second); + + res1 = HashTable.insert(Int::create(1)); + // Check entry is inserted. + EXPECT_TRUE(res1.first.key() == 1); + EXPECT_FALSE(res1.second); + + std::pair res2 = HashTable.insert(Int::create(2)); + // Check old entry is still valid. + EXPECT_TRUE(res1.first.key() == 1); + // Check new entry is inserted. + EXPECT_TRUE(res2.first.key() == 2); + EXPECT_TRUE(res2.second); + // Check new and old entries are not the equal. + EXPECT_FALSE(res1.first == res2.first); + + std::string StatisticString; + raw_string_ostream StatisticStream(StatisticString); + HashTable.printStatistic(StatisticStream); + + // Verifying that the table contains exactly the number of elements we + // inserted. 
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 2") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddEntriesWithResizing) {
+  static BumpPtrAllocator Allocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTable HashTable(10);
+
+  for (size_t Idx = 0; Idx < 10000; Idx++) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is inserted.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_TRUE(res1.second);
+
+    res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is found.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  }
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+
+  for (size_t Idx = 0; Idx < 10000; Idx++) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is found.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  }
+
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddEntriesParralel) {
+  static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return DataAllocator; }
+  };
+  ConcurrentHashTable HashTable(1000000);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is inserted.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_TRUE(res1.second);
+
+    res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is found.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddEntriesWithResizingParralel1) {
+  static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return DataAllocator; }
+  };
+  ConcurrentHashTable HashTable(1000);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is inserted.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_TRUE(res1.second);
+
+    res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is found.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddEntriesWithResizingParralel2) {
+  static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return DataAllocator; }
+  };
+  ConcurrentHashTable HashTable(10);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is inserted.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_TRUE(res1.second);
+
+    res1 = HashTable.insert(Int::create(Idx));
+    // Check entry is found.
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Int::create(Idx));
+    EXPECT_TRUE(res1.first.key() == Idx);
+    EXPECT_FALSE(res1.second);
+  });
+
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddRandomEntriesParralel) {
+  std::vector Data;
+  std::set UniqData;
+  std::linear_congruential_engine rnd(
+      std::chrono::system_clock::now().time_since_epoch().count());
+  for (size_t Idx = 0; Idx < 10000; Idx++) {
+    size_t RndNum = rnd();
+    Data.push_back(Int::create(RndNum));
+    UniqData.insert(RndNum);
+  }
+
+  static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return DataAllocator; }
+  };
+  ConcurrentHashTable HashTable(10);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Data[Idx]);
+    // Random keys may repeat, so only the key (not .second) is checked.
+    EXPECT_TRUE(res1.first.key() == Data[Idx].key());
+
+    res1 = HashTable.insert(Data[Idx]);
+    // Inserting the same entry again must find the existing one.
+    EXPECT_TRUE(res1.first.key() == Data[Idx].key());
+    EXPECT_FALSE(res1.second);
+  });
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = " +
+                                   itostr(UniqData.size())) !=
+              std::string::npos);
+
+  parallelFor(0, 10000, [&](size_t Idx) {
+    std::pair res1 = HashTable.insert(Data[Idx]);
+    EXPECT_TRUE(res1.first.key() == Data[Idx].key());
+    EXPECT_FALSE(res1.second);
+  });
+
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = " +
+                                   itostr(UniqData.size())) !=
+              std::string::npos);
+}
+
+using ExtraDataTy = std::array;
+
+class String {
+public:
+  String() {}
+  const std::string &key() const { return Data; }
+
+  static String *create(const std::string &Num, BumpPtrAllocator &Allocator) {
+    String *Result = Allocator.Allocate();
+    new (Result) String(Num);
+    return Result;
+  }
+
+protected:
+  String(const std::string &Num) { Data += Num; }
+
+  std::string Data;
+  ExtraDataTy ExtraData;
+};
+
+TEST(ConcurrentHashTableTest, AddStringEntries) {
+  static BumpPtrAllocator Allocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTableByPtr HashTable(10);
+
+  std::pair res1 = HashTable.insert("1");
+  // Check entry is inserted.
+  EXPECT_TRUE(res1.first->key() == "1");
+  EXPECT_TRUE(res1.second);
+
+  std::pair res2 = HashTable.insert("2");
+  // Check old entry is still valid.
+  EXPECT_TRUE(res1.first->key() == "1");
+  // Check new entry is inserted.
+  EXPECT_TRUE(res2.first->key() == "2");
+  EXPECT_TRUE(res2.second);
+  // Check new and old entries use different memory.
+  EXPECT_TRUE(res1.first != res2.first);
+
+  std::pair res3 = HashTable.insert("3");
+  // Check one more entry is inserted.
+  EXPECT_TRUE(res3.first->key() == "3");
+  EXPECT_TRUE(res3.second);
+
+  std::pair res4 = HashTable.insert("1");
+  // Check duplicated entry is found rather than inserted.
+  EXPECT_TRUE(res4.first->key() == "1");
+  EXPECT_FALSE(res4.second);
+  // Check duplicated entry matches with the first one.
+  EXPECT_TRUE(res1.first == res4.first);
+
+  // Check first entry is still valid.
+  EXPECT_TRUE(res1.first->key() == "1");
+
+  // Check data was allocated by allocator.
+  EXPECT_TRUE(Allocator.getBytesAllocated() > 0);
+
+  // Check statistic.
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 3") !=
+              std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddStringEntriesParallel) {
+  // Number of elements far exceeds the initial size (10), thus the
+  // hashtable must be resized during the parallel insertion.
+  const size_t NumElements = 10000;
+  static LLVM_THREAD_LOCAL BumpPtrAllocator Allocator;
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTableByPtr HashTable(
+      10);
+
+  // Check parallel insertion.
+  parallelFor(0, NumElements, [&](size_t I) {
+    std::string StringForElement = formatv("{0}", I);
+    std::pair Entry = HashTable.insert(StringForElement);
+    EXPECT_TRUE(Entry.second);
+    EXPECT_TRUE(Entry.first->key() == StringForElement);
+    EXPECT_TRUE(Allocator.getBytesAllocated() > 0);
+  });
+
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  // Verifying that the table contains exactly the number of elements we
+  // inserted.
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+
+  // Check parallel insertion of duplicates.
+  parallelFor(0, NumElements, [&](size_t I) {
+    size_t BytesAllocated = Allocator.getBytesAllocated();
+    std::string StringForElement = formatv("{0}", I);
+    std::pair Entry = HashTable.insert(StringForElement);
+    EXPECT_FALSE(Entry.second);
+    EXPECT_TRUE(Entry.first->key() == StringForElement);
+    // Check no additional bytes were allocated for duplicate.
+    EXPECT_TRUE(Allocator.getBytesAllocated() == BytesAllocated);
+  });
+
+  // Re-print the statistic: duplicates must not have added new entries.
+  StatisticString.erase();
+  HashTable.printStatistic(StatisticStream);
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000") !=
+              std::string::npos);
+}
+
+} // namespace