diff --git a/llvm/include/llvm/ADT/ConcurrentHashtable.h b/llvm/include/llvm/ADT/ConcurrentHashtable.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/ADT/ConcurrentHashtable.h @@ -0,0 +1,773 @@ +//===- ConcurrentHashtable.h ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_CONCURRENTHASHTABLE_H +#define LLVM_ADT_CONCURRENTHASHTABLE_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Parallel.h" +#include "llvm/Support/WithColor.h" +#include +#include +#include +#include +#include +#include + +namespace llvm { + +/// ConcurrentHashTable - is a resizeable concurrent hashtable. +/// The number of resizings limited to x2^32. The hashtable allows +/// only concurrent insertions(Though deletions could be easily added). +/// +/// InsertedValue = insert ( NewValue ); +/// +/// Data structure: +/// +/// Inserted value is mapped to 64-bit hash value -> +/// +/// [------- 64-bit Hash value --------] +/// [ ChainIndex ][ Bucket Index ] +/// | | +/// points to the points to +/// bucket chains. the bucket. +/// +/// After initialization, all buckets consist of one chain(not inititalized +/// at start). During insertions, buckets might be extended to contain more +/// chains. The number of chains kept by bucket is called bucket width. Each +/// bucket can be independently resized and rehashed(no need to lock the whole +/// table). Different buckets may have different widths. 
+/// +/// HashTablesSet keeps chains of all buckets: +/// +/// HashTablesSet[ChainIdx][BucketIdx]: +/// +/// [ Bucket 0 ][ Bucket 1 ][Bucket ...][ Bucket M ] +/// [Chain 0] [Chain head][Chain head][Chain head][Chain head] +/// [Chain 1] [Chain head][Chain head][Chain head][Chain head] +/// [Chain 2] [Chain head][Chain head][Chain head][Chain head] +/// .......... variable size ............... +/// [Chain N] [Chain head][Chain head][Chain head][Chain head] +/// +/// Last bucket index == HashTableSize - 1 +/// Last chain index == BucketWidth - 1 +/// +/// Each chain keeps an array of EntriesPerChainItem entries: +/// +/// ChainHead->[ KeyDataTy* data ][ uint32_t high part of hash ] +/// [ KeyDataTy* data ][ uint32_t high part of hash ] +/// [ KeyDataTy* data ][ uint32_t high part of hash ] +/// [ KeyDataTy* data ][ uint32_t high part of hash ] +/// +/// Pointer to the each bucket chain also keeps the state of the chain: +/// joined or disjoined. + +template +class ConcurrentHashTableInfoByPtr { +public: + /// \returns Hash value for the specified \p Key. + static inline hash_code getHashValue(const KeyTy &Key) { + return std::hash()(Key); + } + + /// \returns true if both \p LHS and \p RHS are equal. + static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) { + return LHS == RHS; + } + + /// \returns key for the specified \p KeyData. + static inline const KeyTy &getKey(const KeyDataTy &KeyData) { + return KeyData.getKey(); + } + + /// \returns newly created object of KeyDataTy type. + static inline KeyDataTy *create(const KeyTy &Key, AllocatorTy &Allocator) { + return KeyDataTy::create(Key, Allocator); + } +}; + +class ConcurrentHashTableConstants { +public: + // Define the number of entries per chain item. + static size_t constexpr EntriesPerChainItem = 32; + + // Define the number of mutexes. + static size_t constexpr MutexesInitialSize = 512; + + // Specify whether buckets should be rehashed with small steps or, + // otherwise, with widest possible steps. 
+ static bool constexpr GreedyBucketsRehashing = false; + + // Specify whether chains would be joined. It saves memory but reduces + // performance a little. + static bool constexpr JoinChains = true; +}; + +/// AllocatorRefTy type allows to get access either to: +/// +/// 1. thread local allocator. +/// +/// static LLVM_THREAD_LOCAL BumpPtrAllocator DataAllocator; +/// class AllocatorRef { +/// public: +/// static inline BumpPtrAllocator& getAllocatorRef() { +/// return DataAllocator; +/// } +/// }; +/// +/// ASSUMPTION: ThreadPool is used to keep threads alive and so to keep +/// data. +/// +/// 2. thread safe allocator. +/// +/// static ThreadSafeBumpPtrAllocator DataAllocator; +/// class AllocatorRef { +/// public: +/// static inline BumpPtrAllocator& getAllocatorRef() { +/// return DataAllocator; +/// } +/// }; +/// + +template , + typename Constants = ConcurrentHashTableConstants> +class ConcurrentHashTableByPtr { +public: + static_assert((Constants::MutexesInitialSize > 0), + "MutexesInitialSize must be greater than 0."); + static_assert((Constants::EntriesPerChainItem > 0), + "EntriesPerChainItem must be greater than 0."); + + ConcurrentHashTableByPtr( + size_t EstimatedSize = 100000, + size_t ThreadsNum = parallel::strategy.compute_thread_count()) { + assert(ThreadsNum > 0); + + // Calculate hashtable size. + EstimatedSize /= 2; + EstimatedSize /= Constants::EntriesPerChainItem; + HashTableSize = std::max(EstimatedSize, (size_t)16); + // Make hash table to be memory page size aligned. + HashTableSize = alignTo(HashTableSize, 4096 / sizeof(HashTableEntry)); + // Make size to be power of 2. + HashTableSize = PowerOf2Ceil(HashTableSize); + + // Calculate number of mutexes. + if (ThreadsNum == 1) + BucketMutexesNum = 1; + else + BucketMutexesNum = + NextPowerOf2(Constants::MutexesInitialSize * ThreadsNum); + HashTableSize = std::max(HashTableSize, BucketMutexesNum); + + // Allocate first hashtable. 
+ NumberOfHashTables = 0; + allocateNewHashTables(1); + + BucketWidths.resize(HashTableSize); + + // Allocate mutexes. + BucketMutexes = + AllocatorRefTy::getAllocatorRef().template Allocate( + BucketMutexesNum); + for (size_t Idx = 0; Idx < BucketMutexesNum; Idx++) + new (&BucketMutexes[Idx]) std::mutex(); + + // Calculate masks. + unsigned LeadingZerosNumber = countLeadingZeros(getHashMask()); + HashBitsNum = sizeof(hash_code) * 8 - LeadingZerosNumber; + + // We keep only high 32-bits of hash value. So bucket width cannot + // exceed 2^32. Bucket width is always power of two. + MaxBucketWidth = 1Ull << (std::min(32U, LeadingZerosNumber)); + + // Calculate mask for extended hash bits. + ExtHashMask = ((HashTableSize * MaxBucketWidth) - 1) & ~getHashMask(); + } + + /// Insert new value \p NewValue or return already existing entry. + /// + /// \returns entry and "true" if an entry is just inserted or + /// "false" if an entry already exists. + std::pair insert(const KeyTy &NewValue) { + // Calculate bucket index. + hash_code Hash = Info::getHashValue(NewValue); + size_t BucketIdx = Hash & getHashMask(); + uint32_t ExtHashBits = getExtHashBits(Hash); + std::mutex &BucketMutex = getBucketMutex(BucketIdx); + + // Lock bucket. + BucketMutex.lock(); + + size_t BucketWidth = getBucketWidth(BucketIdx); + + while (true) { + // Get chain head. + size_t ChainIdx = getChainIdx(ExtHashBits, BucketWidth); + std::pair ChainHead = + getOrCreateChainHead(BucketIdx, ChainIdx, BucketWidth); + + // For each chain entry... + uint32_t *Hashes = ChainHead.first->ExtHashBits; + KeyDataTy **Data = ChainHead.first->Data; + for (size_t EntryIdx = 0; EntryIdx < Constants::EntriesPerChainItem; + EntryIdx++) { + if (Data[EntryIdx] == nullptr) { + // found empty slot. Insert data. 
+ KeyDataTy *NewData = + Info::create(NewValue, AllocatorRefTy::getAllocatorRef()); + Data[EntryIdx] = NewData; + + if (EntryIdx + 1 < Constants::EntriesPerChainItem) + Data[EntryIdx + 1] = nullptr; + + Hashes[EntryIdx] = ExtHashBits; + + BucketMutex.unlock(); + + return {NewData, true}; + } + + if (Hashes[EntryIdx] == ExtHashBits) { + // Hash matched. Check value for equality. + KeyDataTy *FoundData = Data[EntryIdx]; + if (Info::isEqual(Info::getKey(*FoundData), NewValue)) { + // Already existed entry matched with inserted data is found. + BucketMutex.unlock(); + + return {FoundData, false}; + } + } + } + + if (ChainHead.second != Disjoined) { + // No more free slots in current chain. Split current chain, + // if this is a joined chain. + splitJoinedChain(ChainHead, BucketIdx, ChainIdx, BucketWidth); + continue; + } + + // No more free slots in current chain. Extend the bucket width. + BucketWidth = extendTable(BucketIdx, BucketWidth); + } + + llvm_unreachable("Insertion error."); + return std::pair(); + } + + /// Print information about current state of hash table structures. + void printStatistic(raw_ostream &OS) { + OS << "\n--- HashTable statistic:\n"; + OS << "\nTable Size = " << HashTableSize; + OS << "\nNumber of tables = " << NumberOfHashTables; + OS << "\nEntries per chain item = " << (int)Constants::EntriesPerChainItem; + OS << "\nNumber of mutexes = " << (int)BucketMutexesNum; + + uint64_t NumberOfBuckets = 0; + uint64_t NumberOfEntries = 0; + uint64_t NumberOfEntriesPlusEmpty = 0; + uint64_t OverallSize = + sizeof(*this) + + HashTableSize * NumberOfHashTables * sizeof(ChainItem *) + + BucketWidths.size() + BucketMutexesNum * sizeof(std::mutex); + uint64_t NumberOfChainItems = 0; + uint64_t NumberOfJoinedChains = 0; + + DenseMap BucketWidthsMap; + + // For each bucket... 
+ for (uint64_t CurBucketIdx = 0; CurBucketIdx < HashTableSize; + CurBucketIdx++) { + + size_t BucketWidth = getBucketWidth(CurBucketIdx); + BucketWidthsMap[BucketWidth]++; + + bool IsEmptyBucket = true; + + // For each chain... + for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + + std::pair ChainHead = + getChainHead(CurBucketIdx, CurChainIdx); + if (ChainHead.first == nullptr) + continue; + + NumberOfChainItems++; + if (ChainHead.second != Disjoined) + NumberOfJoinedChains++; + + // Calculate data from joined chains only once. + if (ChainHead.second != Disjoined && + (CurChainIdx != getBaseChainIdxForJoinedChain(CurChainIdx))) + continue; + + IsEmptyBucket = false; + + // For each chain entry... + size_t EntryIdx = 0; + for (; EntryIdx < Constants::EntriesPerChainItem; EntryIdx++) { + if (ChainHead.first->Data[EntryIdx] == nullptr) + break; + } + + NumberOfEntries += EntryIdx; + NumberOfEntriesPlusEmpty += Constants::EntriesPerChainItem; + + OverallSize += sizeof(ChainItem); + } + + if (!IsEmptyBucket) + NumberOfBuckets++; + } + + OS << "\nOverall number of entries = " << NumberOfEntries; + OS << "\nOverall number of chains = " << NumberOfChainItems; + OS << "\nOverall number of joined chains = " << NumberOfJoinedChains; + OS << "\nOverall number of buckets = " << NumberOfBuckets; + for (auto &Width : BucketWidthsMap) + OS << "\n Number of buckets with width " << Width.first << ": " + << Width.second; + + std::stringstream stream; + stream << std::fixed << std::setprecision(2) + << ((float)NumberOfEntries / (float)NumberOfEntriesPlusEmpty); + std::string str = stream.str(); + + OS << "\nLoad factor = " << str; + OS << "\nOverall allocated size = " << OverallSize; + } + + size_t getMaxBucketWidth() const { return NumberOfHashTables; } + +protected: + struct ChainItem { + ChainItem() = delete; + ChainItem(const ChainItem &) = delete; + ChainItem &operator=(const ChainItem &) = delete; + + KeyDataTy *Data[Constants::EntriesPerChainItem]; + 
uint32_t ExtHashBits[Constants::EntriesPerChainItem]; + }; + + struct DataWithExtHashBits { + uint32_t ExtHashBits = 0; + KeyDataTy *Data = nullptr; + }; + + using HashTableEntry = ChainItem *; + using HashTablePtr = HashTableEntry *; + + // Adjacent chains could share single body. The JoinedChainState keeps + // state for chain(whether it is joined or not). + enum JoinedChainState : uint8_t { + Disjoined = 0, + Joined = 1, + }; + + // Pointers to chain head keep joined state. Thus it should reserve enought + // bits to keep JoinedChainState values. + using JoinBitsEntry = PointerIntPair; + + // Return chain head. + std::pair getChainHead(size_t BucketIdx, + size_t ChainIdx) { + ChainItem *ChainHead = HashTablesSet[ChainIdx][BucketIdx]; + + JoinBitsEntry JoinedHead = JoinBitsEntry::getFromOpaqueValue(ChainHead); + return std::make_pair(JoinedHead.getPointer(), JoinedHead.getInt()); + } + + // Return chain head. Create one if neccessary. + std::pair + getOrCreateChainHead(size_t BucketIdx, size_t ChainIdx, size_t BucketWidth) { + ChainItem *ChainHead = HashTablesSet[ChainIdx][BucketIdx]; + JoinBitsEntry JoinedHead = JoinBitsEntry::getFromOpaqueValue(ChainHead); + + if (ChainHead == nullptr) { + if constexpr (!Constants::JoinChains) { + // Do not create joined chains, if Constants::JoinChains is false. + // Just allocate a new chain. + ChainHead = allocateItem(); + JoinedHead.setInt(Disjoined); + JoinedHead.setPointer(ChainHead); + HashTablesSet[ChainIdx][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + return {ChainHead, Disjoined}; + } else { + // Not possible to create joined chains if BucketWidth == 1. 
+ if (BucketWidth == 1) { + ChainHead = allocateItem(); + JoinedHead.setInt(Disjoined); + JoinedHead.setPointer(ChainHead); + assert(ChainIdx == 0); + HashTablesSet[ChainIdx][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + return {ChainHead, Disjoined}; + } + + if (ChainIdx != getBaseChainIdxForJoinedChain(ChainIdx)) + ChainHead = + HashTablesSet[getBaseChainIdxForJoinedChain(ChainIdx)][BucketIdx]; + + if (ChainHead == nullptr) { + // Allocate base chain if it is not existed. + ChainHead = allocateItem(); + JoinedHead.setPointer(ChainHead); + } else { + // Get chain body from the base chain. + JoinedHead = JoinBitsEntry::getFromOpaqueValue(ChainHead); + assert(JoinedHead.getInt() == Disjoined); + } + // Mark chain as joined. + JoinedHead.setInt(Joined); + + // Store the same body to the base and paired chain. + assert(getBaseChainIdxForJoinedChain(ChainIdx) < BucketWidth); + assert(getPairedChainIdxForJoinedChain(ChainIdx) < BucketWidth); + HashTablesSet[getBaseChainIdxForJoinedChain(ChainIdx)][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + HashTablesSet[getPairedChainIdxForJoinedChain(ChainIdx)][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + + return {ChainHead, Joined}; + } + } + + return {JoinedHead.getPointer(), JoinedHead.getInt()}; + } + + bool isEmptyBucket(size_t BucketIdx) { + size_t BucketWidth = getBucketWidth(BucketIdx); + + for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + if (getChainHead(BucketIdx, CurChainIdx).first != nullptr) + return false; + } + + return true; + } + + size_t getBucketWidth(size_t BucketIdx) { + assert(BucketIdx < HashTableSize); + return 1 << BucketWidths[BucketIdx]; + } + + void setBucketWidth(size_t BucketIdx, size_t NewBucketWidth) { + assert((NewBucketWidth >= 1) & !(NewBucketWidth & (NewBucketWidth - 1))); + assert(getBucketWidth(BucketIdx) < NewBucketWidth); + + BucketWidths[BucketIdx] = countTrailingZeros(NewBucketWidth); + } + + uint32_t getExtHashBits(hash_code 
Hash) { + return (Hash & ExtHashMask) >> HashBitsNum; + } + + size_t getChainIdx(uint32_t ExtHashBits, size_t BucketWidth) { + assert(BucketWidth > 0); + + return ExtHashBits & (BucketWidth - 1); + } + + // Allocate new chain intem. + LLVM_ATTRIBUTE_ALWAYS_INLINE ChainItem *allocateItem() { + ChainItem *ResultItem = + AllocatorRefTy::getAllocatorRef().template Allocate(); + + assert(ResultItem != nullptr); + ResultItem->Data[0] = nullptr; + + return ResultItem; + } + + void splitJoinedChainToSpecifiedPair(ChainItem *BaseChainHead, + size_t BaseChainIdx, + ChainItem *NewChainHead, + size_t NewChainIdx, size_t BucketWidth) { + uint32_t *Hashes = BaseChainHead->ExtHashBits; + uint32_t *NewHashes = NewChainHead->ExtHashBits; + + KeyDataTy **Data = BaseChainHead->Data; + KeyDataTy **NewData = NewChainHead->Data; + + size_t BaseDstIdx = 0; + size_t NewDstIdx = 0; + for (size_t BaseSrcIdx = 0; BaseSrcIdx < Constants::EntriesPerChainItem; + BaseSrcIdx++) { + size_t RehashedChainIdx = getChainIdx(Hashes[BaseSrcIdx], BucketWidth); + assert(RehashedChainIdx == BaseChainIdx || + RehashedChainIdx == NewChainIdx); + + if (BaseChainIdx == RehashedChainIdx) { + Hashes[BaseDstIdx] = Hashes[BaseSrcIdx]; + Data[BaseDstIdx] = Data[BaseSrcIdx]; + BaseDstIdx++; + } else { + NewHashes[NewDstIdx] = Hashes[BaseSrcIdx]; + NewData[NewDstIdx] = Data[BaseSrcIdx]; + NewDstIdx++; + } + } + Data[BaseDstIdx] = nullptr; + NewData[NewDstIdx] = nullptr; + + assert((BaseDstIdx + NewDstIdx > 0) && + "Resulted chains could not be empty."); + } + + void splitJoinedChain(std::pair ChainHead, + size_t BucketIdx, size_t ChainIdx, size_t BucketWidth) { + assert(ChainIdx < BucketWidth); + assert(ChainHead.second != Disjoined); + + ChainItem *BaseChainHead = ChainHead.first; + size_t BaseChainIdx = getBaseChainIdxForJoinedChain(ChainIdx); + ChainItem *NewChainHead = allocateItem(); + size_t NewChainIdx = getPairedChainIdxForJoinedChain(ChainIdx); + + splitJoinedChainToSpecifiedPair(BaseChainHead, BaseChainIdx, 
NewChainHead, + NewChainIdx, BucketWidth); + + JoinBitsEntry JoinedHead; + JoinedHead.setInt(Disjoined); + JoinedHead.setPointer(BaseChainHead); + HashTablesSet[BaseChainIdx][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + + JoinedHead.setPointer(NewChainHead); + HashTablesSet[NewChainIdx][BucketIdx] = + static_cast(JoinedHead.getOpaqueValue()); + } + + // Extend HashTables and rehash specified bucket. + LLVM_ATTRIBUTE_ALWAYS_INLINE size_t extendTable(size_t BucketIdx, + size_t BucketWidth) { + assert(BucketWidth == getBucketWidth(BucketIdx)); + + if (BucketWidth >= MaxBucketWidth) + report_fatal_error("ConcurrentHashTable is full"); + + return ExtendAndRehashBucket(BucketIdx, BucketWidth); + } + + using CollectedValuesVector = SmallVector; + + // This function adds new chains to the bucket and rehash existing + // values to place them into the new chains. + size_t ExtendAndRehashBucket(size_t BucketIdx, size_t BucketWidth) { + assert(BucketWidth < MaxBucketWidth); + assert(BucketWidth > 0); + + size_t NewBucketsWidth = BucketWidth << 1; + if constexpr (Constants::GreedyBucketsRehashing) + NewBucketsWidth = std::max((size_t)NumberOfHashTables, NewBucketsWidth); + assert(NewBucketsWidth <= MaxBucketWidth); + + // Check whether we need to create new hashtables. + if (NewBucketsWidth > NumberOfHashTables) { + std::unique_lock Guard(HashTableMutex); + + if (NewBucketsWidth > NumberOfHashTables) + allocateNewHashTables(NewBucketsWidth); + } + + // Collect bucket values. + CollectedValuesVector Values; + Values.reserve(Constants::EntriesPerChainItem * BucketWidth); + collectValuesLocked(BucketIdx, BucketWidth, Values); + assert(Values.size() > 0); + + // Insert values into extended bucket. + insertValuesLocked(Values, BucketIdx, NewBucketsWidth); + + // Update bucket width. 
+ setBucketWidth(BucketIdx, NewBucketsWidth); + + assert(getBucketWidth(BucketIdx) == NewBucketsWidth); + assert(!isEmptyBucket(BucketIdx)); + + return NewBucketsWidth; + } + + // Collect all data into the specified vector Values. + // Erase data from all bucket's chains. + void collectValuesLocked(size_t BucketIdx, size_t BucketWidth, + CollectedValuesVector &Values) { + // Collect bucket values. + for (size_t CurChainIdx = 0; CurChainIdx < BucketWidth; CurChainIdx++) { + + std::pair ChainHead = + getChainHead(BucketIdx, CurChainIdx); + + if (ChainHead.first == nullptr) + continue; + + assert(BucketWidth > 0 || ChainHead.second == Disjoined); + + // Copy data from joined chains only once. + if (ChainHead.second != Disjoined && + (CurChainIdx != getBaseChainIdxForJoinedChain(CurChainIdx))) + continue; + + for (size_t Idx = 0; Idx < Constants::EntriesPerChainItem; Idx++) { + KeyDataTy *Data = ChainHead.first->Data[Idx]; + // End of data marker is found. + if (Data == nullptr) + break; + + // Copy data to the Values vector. + Values.emplace_back( + DataWithExtHashBits{ChainHead.first->ExtHashBits[Idx], Data}); + + // Write end of data marker. + if (Idx == 0) + ChainHead.first->Data[Idx] = nullptr; + } + } + } + + size_t getBaseChainIdxForJoinedChain(size_t ChainIdx) { + return (ChainIdx & (~1)); + } + + size_t getPairedChainIdxForJoinedChain(size_t ChainIdx) { + return (ChainIdx | 1); + } + + size_t getJoinedPositionIdx(SmallVector &ChainInsertionPositionIdxs, + size_t ChainIdx, JoinedChainState ChainState) { + if (ChainState == Disjoined) + return ChainInsertionPositionIdxs[ChainIdx]; + + return ChainInsertionPositionIdxs[getBaseChainIdxForJoinedChain(ChainIdx)] + + ChainInsertionPositionIdxs[getPairedChainIdxForJoinedChain( + ChainIdx)]; + } + + // Insert data into the locked bucket. 
+ void insertValuesLocked(CollectedValuesVector &Values, size_t BucketIdx, + size_t BucketWidth) { + SmallVector ChainInsertionPositionIdxs; + ChainInsertionPositionIdxs.resize(BucketWidth); + + for (const DataWithExtHashBits &NewValue : Values) { + assert(NewValue.Data != nullptr); + + size_t ChainIdx = getChainIdx(NewValue.ExtHashBits, BucketWidth); + + bool ValueInserted = false; + while (!ValueInserted) { + std::pair ChainHead = + getOrCreateChainHead(BucketIdx, ChainIdx, BucketWidth); + assert(ChainHead.first != nullptr); + + size_t LastEntryIdx = getJoinedPositionIdx(ChainInsertionPositionIdxs, + ChainIdx, ChainHead.second); + + if (LastEntryIdx >= Constants::EntriesPerChainItem) { + // No more room for the data in ChainIdx chain. Split joined chain. + + assert(ChainIdx < BucketWidth); + assert(ChainHead.second != Disjoined); + splitJoinedChain(ChainHead, BucketIdx, ChainIdx, BucketWidth); + continue; + } + + assert(LastEntryIdx < Constants::EntriesPerChainItem); + assert(ChainHead.first->Data[LastEntryIdx] == nullptr); + + // Write data to the chain. + ChainHead.first->ExtHashBits[LastEntryIdx] = NewValue.ExtHashBits; + ChainHead.first->Data[LastEntryIdx] = NewValue.Data; + ChainInsertionPositionIdxs[ChainIdx]++; + + // Write end of data marker. + if (LastEntryIdx + 1 < Constants::EntriesPerChainItem) + ChainHead.first->Data[LastEntryIdx + 1] = nullptr; + + ValueInserted = true; + } + } + } + + // Allocate new hashtables for specified NewBucketsWidth. + void allocateNewHashTables(size_t NewBucketsWidth) { + assert((NewBucketsWidth >= 1) & !(NewBucketsWidth & (NewBucketsWidth - 1))); + assert(NumberOfHashTables < NewBucketsWidth); + + AllocatorTy &Allocator = AllocatorRefTy::getAllocatorRef(); + + HashTablePtr *NewHashTablesSet = + Allocator.template Allocate(NewBucketsWidth); + // Allocate and initialize new tables, copy old tables. + // TODO: tables initialization might be done in parallel. 
+ for (size_t Idx = 0; Idx < NewBucketsWidth; Idx++) { + if (Idx < NumberOfHashTables) { + NewHashTablesSet[Idx] = HashTablesSet[Idx]; + continue; + } + + HashTablePtr NewTable = + Allocator.template Allocate(HashTableSize); + memset(NewTable, 0, sizeof(HashTableEntry) * HashTableSize); + NewHashTablesSet[Idx] = NewTable; + } + + HashTablesSet = NewHashTablesSet; + NumberOfHashTables = NewBucketsWidth; + } + + // Return mutex for the specified BucketIdx. + std::mutex &getBucketMutex(size_t BucketIdx) { + return BucketMutexes[BucketIdx & (BucketMutexesNum - 1)]; + } + + size_t getHashMask() { return HashTableSize - 1; } + + // Number of bits in hash mask. + uint64_t HashBitsNum = 0; + + // Hash mask for the extended hash bits. + uint64_t ExtHashMask = 0; + + // Array of mutexes. + std::mutex *BucketMutexes = nullptr; + + // Number of used mutexes. + size_t BucketMutexesNum = 2; + + // The maximal bucket width. + size_t MaxBucketWidth = 0; + + // The size of single hashtable. + size_t HashTableSize = 0; + + // The guard for HashTables array. + std::mutex HashTableMutex; + + // The current number of HashTables. + std::atomic NumberOfHashTables; + + SmallVector BucketWidths; + + // HashTables keeping buckets chains. 
+ std::atomic HashTablesSet = nullptr; +}; + +} // end namespace llvm + +#endif // LLVM_ADT_CONCURRENTHASHTABLE_H diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -17,6 +17,7 @@ BumpPtrListTest.cpp CoalescingBitVectorTest.cpp CombinationGeneratorTest.cpp + ConcurrentHashtableTest.cpp DAGDeltaAlgorithmTest.cpp DeltaAlgorithmTest.cpp DenseMapTest.cpp diff --git a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp @@ -0,0 +1,347 @@ +//===- ConcurrentHashtableTest.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ConcurrentHashtable.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Parallel.h" +#include "gtest/gtest.h" +#include +#include +#include +using namespace llvm; + +namespace { +using ExtraDataTy = std::array; + +class String { +public: + String() {} + const std::string &getKey() const { return Data; } + + static String *create(const std::string &Num, BumpPtrAllocator &Allocator) { + String *Result = Allocator.Allocate(); + new (Result) String(Num); + return Result; + } + +protected: + String(const std::string &Num) { Data += Num; } + + std::string Data; + ExtraDataTy ExtraData; +}; + +class HashTableConstants { +public: + // Define the number of entries per chain item. + static size_t constexpr EntriesPerChainItem = 32; + + // Define the number of mutexes. 
+ static size_t constexpr MutexesInitialSize = 2; + + // Specify whether buckets should be rehashed with small steps or, + // otherwise, with widest possible steps. + static bool constexpr GreedyBucketsRehashing = false; + + // Specify whether chains would be joined. It saves memory but reduces + // performance a little. + static bool constexpr JoinChains = true; +}; + +class HashTableConstantsUseGreedyRehashing { +public: + // Define the number of entries per chain item. + static size_t constexpr EntriesPerChainItem = 32; + + // Define the number of mutexes. + static size_t constexpr MutexesInitialSize = 2; + + // Specify whether buckets should be rehashed with small steps or, + // otherwise, with widest possible steps. + static bool constexpr GreedyBucketsRehashing = true; + + // Specify whether chains would be joined. It saves memory but reduces + // performance a little. + static bool constexpr JoinChains = true; +}; + +class HashTableConstantsDontJoinChains { +public: + // Define the number of entries per chain item. + static size_t constexpr EntriesPerChainItem = 32; + + // Define the number of mutexes. + static size_t constexpr MutexesInitialSize = 2; + + // Specify whether buckets should be rehashed with small steps or, + // otherwise, with widest possible steps. + static bool constexpr GreedyBucketsRehashing = false; + + // Specify whether chains would be joined. It saves memory but reduces + // performance a little. + static bool constexpr JoinChains = false; +}; + +template void testAddStringEntries() { + static BumpPtrAllocator Allocator; + class AllocatorRef { + public: + static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; } + }; + ConcurrentHashTableByPtr< + std::string, String, AllocatorRef, BumpPtrAllocator, + ConcurrentHashTableInfoByPtr, + HashMapConstants> + HashTable(10); + + std::pair res1 = HashTable.insert("1"); + // Check entry is inserted. 
+ EXPECT_TRUE(res1.first->getKey() == "1"); + EXPECT_TRUE(res1.second); + + std::pair res2 = HashTable.insert("2"); + // Check old entry is still valid. + EXPECT_TRUE(res1.first->getKey() == "1"); + // Check new entry is inserted. + EXPECT_TRUE(res2.first->getKey() == "2"); + EXPECT_TRUE(res2.second); + // Check new and old entries use different memory. + EXPECT_TRUE(res1.first != res2.first); + + std::pair res3 = HashTable.insert("3"); + // Check one more entry is inserted. + EXPECT_TRUE(res3.first->getKey() == "3"); + EXPECT_TRUE(res3.second); + + std::pair res4 = HashTable.insert("1"); + // Check duplicated entry is inserted. + EXPECT_TRUE(res4.first->getKey() == "1"); + EXPECT_FALSE(res4.second); + // Check duplicated entry uses the same memory. + EXPECT_TRUE(res1.first == res4.first); + + // Check first entry is still valid. + EXPECT_TRUE(res1.first->getKey() == "1"); + + // Check data was allocated by allocator. + EXPECT_TRUE(Allocator.getBytesAllocated() > 0); + + // Check statistic. 
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+
+  EXPECT_TRUE(StatisticString.find("Overall number of entries = 3") !=
+              std::string::npos);
+}
+
+/// Renders the statistic report of \p HashTable into a fresh string.
+///
+/// The tests below take a new snapshot after every insertion phase, so each
+/// "Overall number of entries" check reflects the current contents of the
+/// table rather than a report captured before later insertions.
+template <typename HashTableTy>
+std::string getStatisticString(HashTableTy &HashTable) {
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  HashTable.printStatistic(StatisticStream);
+  return StatisticStream.str();
+}
+
+TEST(ConcurrentHashTableTest, AddStringEntries) {
+  // A SCOPED_TRACE stays active until the end of its enclosing scope, so each
+  // instantiation gets its own scope; otherwise a failure in a later
+  // instantiation would also be reported with the earlier traces.
+  {
+    SCOPED_TRACE("AddStringEntries for HashTableConstants");
+    testAddStringEntries<HashTableConstants>();
+  }
+  {
+    SCOPED_TRACE("AddStringEntries for HashTableConstantsUseGreedyRehashing");
+    testAddStringEntries<HashTableConstantsUseGreedyRehashing>();
+  }
+  {
+    SCOPED_TRACE("AddStringEntries for HashTableConstantsDontJoinChains");
+    testAddStringEntries<HashTableConstantsDontJoinChains>();
+  }
+}
+
+/// Sequentially inserts NumElements distinct strings into a table constructed
+/// with NumElements as its size argument, then re-inserts the same strings
+/// and checks that every duplicate is found and causes no allocation.
+template <typename HashMapConstants> void testAddStringMultiplueEntries() {
+  const size_t NumElements = 10000;
+  static LLVM_THREAD_LOCAL BumpPtrAllocator Allocator;
+  // Exposes the thread-local allocator to the table as a template argument.
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTableByPtr<
+      std::string, String, AllocatorRef, BumpPtrAllocator,
+      ConcurrentHashTableInfoByPtr<std::string, String, BumpPtrAllocator>,
+      HashMapConstants>
+      HashTable(NumElements);
+
+  // Check insertion.
+  for (size_t I = 0; I < NumElements; I++) {
+    std::string StringForElement = formatv("{0}", I);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_TRUE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    EXPECT_GT(Allocator.getBytesAllocated(), 0u);
+  }
+
+  // Verify that the table contains exactly the number of elements we
+  // inserted.
+  std::string StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 10000"),
+            std::string::npos)
+      << StatisticString;
+
+  // Check insertion of duplicates.
+  for (size_t I = 0; I < NumElements; I++) {
+    size_t BytesAllocated = Allocator.getBytesAllocated();
+    std::string StringForElement = formatv("{0}", I);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_FALSE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    // Check no additional bytes were allocated for duplicate.
+    EXPECT_EQ(Allocator.getBytesAllocated(), BytesAllocated);
+  }
+
+  // Re-collect the statistic so this check actually observes the table after
+  // the duplicate insertions: the number of entries must be unchanged.
+  StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 10000"),
+            std::string::npos)
+      << StatisticString;
+}
+
+TEST(ConcurrentHashTableTest, AddStringMultiplueEntries) {
+  // One scope per instantiation so SCOPED_TRACE messages do not accumulate.
+  {
+    SCOPED_TRACE("AddStringMultiplueEntries for HashTableConstants");
+    testAddStringMultiplueEntries<HashTableConstants>();
+  }
+  {
+    SCOPED_TRACE(
+        "AddStringMultiplueEntries for HashTableConstantsUseGreedyRehashing");
+    testAddStringMultiplueEntries<HashTableConstantsUseGreedyRehashing>();
+  }
+  {
+    SCOPED_TRACE(
+        "AddStringMultiplueEntries for HashTableConstantsDontJoinChains");
+    testAddStringMultiplueEntries<HashTableConstantsDontJoinChains>();
+  }
+}
+
+/// Sequentially inserts NumElements distinct strings into a table constructed
+/// with size argument 100, checks that some bucket grew beyond one chain, and
+/// then re-inserts the same strings to check that every duplicate is found
+/// and causes no allocation.
+template <typename HashMapConstants>
+void testAddStringMultiplueEntriesWithResize() {
+  // Number of elements exceeds original size, thus hashtable should be resized.
+  const size_t NumElements = 20000;
+  static LLVM_THREAD_LOCAL BumpPtrAllocator Allocator;
+  // Exposes the thread-local allocator to the table as a template argument.
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTableByPtr<
+      std::string, String, AllocatorRef, BumpPtrAllocator,
+      ConcurrentHashTableInfoByPtr<std::string, String, BumpPtrAllocator>,
+      HashMapConstants>
+      HashTable(100);
+
+  // Check insertion.
+  for (size_t I = 0; I < NumElements; I++) {
+    std::string StringForElement = formatv("{0} {1}", I, I + 100);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_TRUE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    EXPECT_GT(Allocator.getBytesAllocated(), 0u);
+  }
+
+  // Check that table was resized.
+  EXPECT_TRUE(HashTable.getMaxBucketWidth() > 1);
+
+  // Verify that the table contains exactly the number of elements we
+  // inserted.
+  std::string StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 20000"),
+            std::string::npos)
+      << StatisticString;
+
+  // Check insertion of duplicates.
+  for (size_t I = 0; I < NumElements; I++) {
+    size_t BytesAllocated = Allocator.getBytesAllocated();
+    std::string StringForElement = formatv("{0} {1}", I, I + 100);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_FALSE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    // Check no additional bytes were allocated for duplicate.
+    EXPECT_EQ(Allocator.getBytesAllocated(), BytesAllocated);
+  }
+
+  // Re-collect the statistic so this check actually observes the table after
+  // the duplicate insertions: the number of entries must be unchanged.
+  StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 20000"),
+            std::string::npos)
+      << StatisticString;
+}
+
+TEST(ConcurrentHashTableTest, AddStringMultiplueEntriesWithResize) {
+  // One scope per instantiation so SCOPED_TRACE messages do not accumulate.
+  {
+    SCOPED_TRACE("AddStringMultiplueEntriesWithResize for HashTableConstants");
+    testAddStringMultiplueEntriesWithResize<HashTableConstants>();
+  }
+  {
+    SCOPED_TRACE("AddStringMultiplueEntriesWithResize for "
+                 "HashTableConstantsUseGreedyRehashing");
+    testAddStringMultiplueEntriesWithResize<
+        HashTableConstantsUseGreedyRehashing>();
+  }
+  {
+    SCOPED_TRACE("AddStringMultiplueEntriesWithResize for "
+                 "HashTableConstantsDontJoinChains");
+    testAddStringMultiplueEntriesWithResize<HashTableConstantsDontJoinChains>();
+  }
+}
+
+/// Same scenario as testAddStringMultiplueEntries, but both the insertion and
+/// the duplicate-insertion phases run concurrently via parallelFor on a
+/// default-constructed table. The allocator is LLVM_THREAD_LOCAL, so the
+/// allocation checks inside the lambdas observe the calling thread's
+/// allocator.
+template <typename HashMapConstants> void testAddStringEntriesParallel() {
+  const size_t NumElements = 10000;
+  static LLVM_THREAD_LOCAL BumpPtrAllocator Allocator;
+  // Exposes the thread-local allocator to the table as a template argument.
+  class AllocatorRef {
+  public:
+    static inline BumpPtrAllocator &getAllocatorRef() { return Allocator; }
+  };
+  ConcurrentHashTableByPtr<
+      std::string, String, AllocatorRef, BumpPtrAllocator,
+      ConcurrentHashTableInfoByPtr<std::string, String, BumpPtrAllocator>,
+      HashMapConstants>
+      HashTable;
+
+  // Check parallel insertion.
+  parallelFor(0, NumElements, [&](size_t I) {
+    std::string StringForElement = formatv("{0}", I);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_TRUE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    EXPECT_GT(Allocator.getBytesAllocated(), 0u);
+  });
+
+  // Verify that the table contains exactly the number of elements we
+  // inserted.
+  std::string StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 10000"),
+            std::string::npos)
+      << StatisticString;
+
+  // Check parallel insertion of duplicates.
+  parallelFor(0, NumElements, [&](size_t I) {
+    size_t BytesAllocated = Allocator.getBytesAllocated();
+    std::string StringForElement = formatv("{0}", I);
+    auto Entry = HashTable.insert(StringForElement);
+    EXPECT_FALSE(Entry.second);
+    EXPECT_EQ(Entry.first->getKey(), StringForElement);
+    // Check no additional bytes were allocated for duplicate.
+    EXPECT_EQ(Allocator.getBytesAllocated(), BytesAllocated);
+  });
+
+  // Re-collect the statistic so this check actually observes the table after
+  // the duplicate insertions: the number of entries must be unchanged.
+  StatisticString = getStatisticString(HashTable);
+  EXPECT_NE(StatisticString.find("Overall number of entries = 10000"),
+            std::string::npos)
+      << StatisticString;
+}
+
+TEST(ConcurrentHashTableTest, AddStringEntriesParallel) {
+  // One scope per instantiation so SCOPED_TRACE messages do not accumulate.
+  {
+    SCOPED_TRACE("AddStringEntriesParallel for HashTableConstants");
+    testAddStringEntriesParallel<HashTableConstants>();
+  }
+  {
+    SCOPED_TRACE(
+        "AddStringEntriesParallel for HashTableConstantsUseGreedyRehashing");
+    testAddStringEntriesParallel<HashTableConstantsUseGreedyRehashing>();
+  }
+  {
+    SCOPED_TRACE(
+        "AddStringEntriesParallel for HashTableConstantsDontJoinChains");
+    testAddStringEntriesParallel<HashTableConstantsDontJoinChains>();
+  }
+}
+
+} // namespace