diff --git a/llvm/include/llvm/ADT/LockFreeDataPool.h b/llvm/include/llvm/ADT/LockFreeDataPool.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/ADT/LockFreeDataPool.h
@@ -0,0 +1,290 @@
+//===- LockFreeDataPool.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_LOCKFREEDATAPOOL_H
+#define LLVM_ADT_LOCKFREEDATAPOOL_H
+
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/NativeFormatting.h"
+#include "llvm/Support/WithColor.h"
+#include <algorithm>
+#include <atomic>
+#include <cassert>
+#include <functional>
+#include <mutex>
+#include <new>
+#include <utility>
+
+namespace llvm {
+
+/// LockFreeDataPool - a data pool that keeps data identified by a key.
+/// It uses a partially resizeable lock-free hash table to map keys to the
+/// data. Insertions into the internal hash table are lock-free, and
+/// allocations are done using an external, non-thread-safe allocator that
+/// is passed to every operation which may allocate.
+///
+/// The implementation of the hash table is copied (with modifications)
+/// from lld/COFF/DebugTypes.cpp.
+///
+/// A concurrent hash table for global type hashing. It is based on this paper:
+/// Concurrent Hash Tables: Fast and General(?)!
+/// https://dl.acm.org/doi/10.1145/3309206
+///
+/// This hash table is meant to be used in two phases:
+/// 1. concurrent insertions
+/// 2. concurrent reads
+/// It has no separate lookup operation (insert() returns an already existing
+/// entry), and it does not support deletion or rehashing. It uses linear
+/// probing. Instead of rehashing, an entry whose probe sequence gets too long,
+/// or which does not fit into a full slab, is placed into the next slab;
+/// slabs are allocated on demand.
+///
+/// The paper describes storing a key-value pair in two machine words.
+/// This hash table keeps a pointer to the KeyDataTy class, whose data are
+/// allocated and kept in the external allocator.
+///
+/// Slabs are allocated from the allocator passed to the init()/insert() call
+/// that needs them, so every allocator passed in must outlive the pool. An
+/// allocator must not be used by several threads at the same time.
+///
+/// For KeyTy there should be implemented: hash_code hash_value(KeyTy Key)
+///
+/// For KeyDataTy there should be implemented:
+///   KeyTy KeyDataTy::key() const;
+///   static KeyDataTy *KeyDataTy::create(KeyTy Key, AllocatorTy &Allocator);
+template <typename KeyTy, typename KeyDataTy, typename AllocatorTy>
+class LockFreeDataPool {
+public:
+  /// \p WarningHandler, if set, is called at most once, with the first
+  /// warning the pool reports.
+  explicit LockFreeDataPool(
+      std::function<void(const char *)> WarningHandler = nullptr)
+      : WarningHandler(std::move(WarningHandler)) {}
+
+  /// Initialize the table with the given size. \p InitialSize is rounded to
+  /// a power of two (at least 4096); the first slab is allocated from
+  /// \p Allocator.
+  void init(AllocatorTy &Allocator, size_t InitialSize) {
+    TableSize = std::max<size_t>(NextPowerOf2(InitialSize), 4096);
+    HashMask = TableSize - 1;
+    CurSlabNum = 0;
+    FirstSlab = allocateNextSlab(Allocator);
+  }
+
+  /// Insert new entry allocated by \p Allocator for specified \p Key or
+  /// return already existing entry which key matched with specified \p Key.
+  /// init() must have been called before.
+  ///
+  /// \returns pool entry and "true" - if an entry is just inserted or
+  /// "false" if an entry already exists.
+  std::pair<KeyDataTy *, bool> insert(AllocatorTy &Allocator, KeyTy Key) {
+    assert(FirstSlab != nullptr && "init() must be called before insert()");
+    return insert(Allocator, FirstSlab,
+                  static_cast<size_t>(hash_value(Key)) & HashMask, Key,
+                  nullptr);
+  }
+
+  /// Print information about current state of pool structures.
+  void printStatistic(raw_ostream &OS) const {
+    OS << "\nLockFreeDataPool:\n";
+    OS << "\nNumber of slabs: " << CurSlabNum;
+    OS << "\nNumber of elements per slab: " << TableSize;
+    OS << "\nFull size of pool: "
+       << CurSlabNum * TableSize * sizeof(KeyDataTy *) << " bytes \n";
+
+    size_t OverallNumberOfUsedEntries = 0;
+    for (Slab *CurSlab = FirstSlab; CurSlab != nullptr;
+         CurSlab = CurSlab->NextSlab.load()) {
+      std::pair<size_t, size_t> Statistic = getNumberOfEntriesForSlab(CurSlab);
+      OS << "\n Slab depth: " << CurSlab->SlabDepth;
+      OS << "\n Slab fullness: ";
+      llvm::write_double(OS, static_cast<double>(Statistic.first) / TableSize,
+                         FloatStyle::Percent);
+      OS << "\n Slab longest bucket: " << Statistic.second;
+      OS << "\n";
+
+      OverallNumberOfUsedEntries += Statistic.first;
+    }
+
+    // Avoid a division by zero if init() was never called.
+    size_t OverallNumberOfEntries = CurSlabNum * TableSize;
+    OS << "\nOverall number of used elements: " << OverallNumberOfUsedEntries;
+    OS << "\nOverall fullness of pool: ";
+    llvm::write_double(OS,
+                       OverallNumberOfEntries == 0
+                           ? 0.0
+                           : static_cast<double>(OverallNumberOfUsedEntries) /
+                                 OverallNumberOfEntries,
+                       FloatStyle::Percent);
+  }
+
+protected:
+  /// A fixed-size array of hash table cells plus a link to the next slab.
+  struct Slab {
+    std::atomic<KeyDataTy *> *Table = nullptr;
+    std::atomic<Slab *> NextSlab{nullptr};
+    size_t SlabDepth = 0;
+  };
+
+  // \returns next available slab, allocate a new one if necessary.
+  // The fast path is a lock-free load; allocation is serialized by
+  // SlabsMutex and the pointer is re-checked under the lock.
+  Slab *getNextSlab(Slab *CurSlab, AllocatorTy &Allocator) {
+    if (Slab *Existing = CurSlab->NextSlab.load())
+      return Existing;
+
+    const std::lock_guard<std::mutex> Lock(SlabsMutex);
+    Slab *Next = CurSlab->NextSlab.load();
+    if (Next == nullptr) {
+      Next = allocateNextSlab(Allocator);
+      CurSlab->NextSlab.store(Next);
+    }
+    return Next;
+  }
+
+  // Allocate and initialize a new, empty slab. Must be called either before
+  // concurrent insertions start (init()) or with SlabsMutex held.
+  Slab *allocateNextSlab(AllocatorTy &Allocator) {
+    Slab *NewSlab = new (Allocator.template Allocate<Slab>()) Slab();
+
+    // Construct the atomics in place rather than memset()ing raw memory.
+    NewSlab->Table =
+        Allocator.template Allocate<std::atomic<KeyDataTy *>>(TableSize);
+    for (size_t Idx = 0; Idx < TableSize; ++Idx)
+      new (&NewSlab->Table[Idx]) std::atomic<KeyDataTy *>(nullptr);
+
+    NewSlab->SlabDepth = CurSlabNum++;
+
+    if (CurSlabNum > MaxNumberOfSlabs) {
+      if (CurSlabNum > 2 * MaxNumberOfSlabs)
+        report_fatal_error("Hash table is full.");
+
+      reportWarning("Too many slabs. Consider increasing initial size or "
+                    "changing hash function.");
+    }
+
+    return NewSlab;
+  }
+
+  // Insert \p NewData into the specified \p CurSlab. Start probing from
+  // the \p StartIdx. Allocate new slab if specified slab \p CurSlab is full
+  // or the probe sequence exceeds MaxBucketLength. \p NewData may be null,
+  // in which case the data is created lazily from \p Key.
+  std::pair<KeyDataTy *, bool> insert(AllocatorTy &Allocator, Slab *CurSlab,
+                                      size_t StartIdx, KeyTy Key,
+                                      KeyDataTy *NewData) {
+    // Do a linear probe starting at StartIdx.
+    size_t Idx = StartIdx;
+    size_t CurBucketSize = 0;
+
+    while (true) {
+      // Run a compare and swap loop. There are three cases:
+      //  - cell is empty: CAS into place and return
+      //  - cell has matching key, return cell`s data
+      //  - cell has non-matching key: hash collision, probe next cell
+      KeyDataTy *Candidate = CurSlab->Table[Idx].load();
+      if (Candidate == nullptr) {
+        if (CurBucketSize >= MaxBucketLength)
+          // Insert data into the next slab.
+          return insert(Allocator, getNextSlab(CurSlab, Allocator), StartIdx,
+                        Key, NewData);
+
+        // The data is created only once. If the CAS below loses the race,
+        // NewData is reused for a later cell; if the winner has the same key,
+        // NewData stays unused in the allocator.
+        if (NewData == nullptr)
+          NewData = KeyDataTy::create(Key, Allocator);
+
+        if (CurSlab->Table[Idx].compare_exchange_weak(Candidate, NewData))
+          return std::make_pair(NewData, true);
+        continue;
+      }
+
+      if (Candidate->key() == Key)
+        return std::make_pair(Candidate, false);
+
+      ++CurBucketSize;
+      // Advance the probe. Wrap around to the beginning if we run off the end.
+      ++Idx;
+      Idx = Idx == TableSize ? 0 : Idx;
+      if (Idx == StartIdx) {
+        reportWarning("Slab is 100% full. Consider increasing initial size.");
+
+        return insert(Allocator, getNextSlab(CurSlab, Allocator), StartIdx, Key,
+                      NewData);
+      }
+    }
+    llvm_unreachable("left infloop");
+  }
+
+  // \returns number of used elements inside slab and the length (number of
+  // consecutive used cells) of the longest bucket inside slab. A bucket that
+  // runs off the end of the table continues at index 0, as probing does.
+  std::pair<size_t, size_t> getNumberOfEntriesForSlab(Slab *CurSlab) const {
+    size_t NumberOfElements = 0;
+    size_t LongestBucketLength = 0;
+    size_t CurBucketLength = 0;
+    // Length of the bucket which starts at index 0.
+    size_t LeadingBucketLength = 0;
+    bool InLeadingBucket = true;
+
+    for (size_t CurIdx = 0; CurIdx < TableSize; CurIdx++) {
+      if (CurSlab->Table[CurIdx].load() == nullptr) {
+        LongestBucketLength = std::max(LongestBucketLength, CurBucketLength);
+        CurBucketLength = 0;
+        InLeadingBucket = false;
+        continue;
+      }
+
+      NumberOfElements++;
+      CurBucketLength++;
+      if (InLeadingBucket)
+        LeadingBucketLength++;
+    }
+
+    if (InLeadingBucket)
+      // Every cell is used: the whole slab is a single bucket.
+      LongestBucketLength = NumberOfElements;
+    else
+      // Join the trailing bucket with the wrapped-around leading one.
+      LongestBucketLength = std::max(LongestBucketLength,
+                                     CurBucketLength + LeadingBucketLength);
+
+    return std::make_pair(NumberOfElements, LongestBucketLength);
+  }
+
+  // Call the warning handler for the first reported warning only.
+  // exchange() keeps this race-free when several threads report at once.
+  void reportWarning(const char *Message) {
+    if (WarningIsReported.exchange(true))
+      return;
+    if (WarningHandler)
+      WarningHandler(Message);
+  }
+
+  size_t TableSize = 0;
+  size_t HashMask = 0;
+  const size_t MaxBucketLength = 100;
+  const size_t MaxNumberOfSlabs = 3;
+
+  // Serializes allocation of new slabs.
+  std::mutex SlabsMutex;
+  Slab *FirstSlab = nullptr;
+  size_t CurSlabNum = 0;
+
+  std::function<void(const char *)> WarningHandler;
+  // Set once the first warning has been reported; written concurrently.
+  std::atomic<bool> WarningIsReported{false};
+};
+
+} // end namespace llvm
+
+#endif // LLVM_ADT_LOCKFREEDATAPOOL_H
diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt
--- a/llvm/unittests/ADT/CMakeLists.txt
+++ b/llvm/unittests/ADT/CMakeLists.txt
@@ -43,6 +43,7 @@
   IntervalMapTest.cpp
   IntrusiveRefCntPtrTest.cpp
   IteratorTest.cpp
+  LockFreeDataPoolTest.cpp
   MappedIteratorTest.cpp
   MapVectorTest.cpp
   OptionalTest.cpp
diff --git a/llvm/unittests/ADT/LockFreeDataPoolTest.cpp b/llvm/unittests/ADT/LockFreeDataPoolTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/ADT/LockFreeDataPoolTest.cpp
@@ -0,0 +1,140 @@
+//===- LockFreeDataPoolTest.cpp - LockFreeDataPool unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/LockFreeDataPool.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Parallel.h"
+#include "gtest/gtest.h"
+#include <string>
+#include <utility>
+#include <vector>
+using namespace llvm;
+
+namespace {
+
+class StringData {
+public:
+  static StringData *create(StringRef Key, BumpPtrAllocator &Allocator) {
+    return new (Allocator) StringData(Key, Allocator);
+  }
+
+  StringRef key() const { return Data; }
+
+protected:
+  // The key bytes are copied into the allocator as well: objects created by
+  // the BumpPtrAllocator are never destroyed, so an owning std::string
+  // member would leak its heap buffer for keys longer than the SSO limit.
+  StringData(StringRef Key, BumpPtrAllocator &Allocator)
+      : Data(Key.copy(Allocator)) {}
+  StringRef Data;
+};
+
+TEST(LockFreeDataPoolTest, AddEntries) {
+  BumpPtrAllocator Allocator;
+  LockFreeDataPool<StringRef, StringData, BumpPtrAllocator> DataPool;
+
+  DataPool.init(Allocator, 10);
+
+  std::pair<StringData *, bool> res1 = DataPool.insert(Allocator, "1");
+  // Check entry is inserted.
+  EXPECT_TRUE(res1.first->key() == "1");
+  EXPECT_TRUE(res1.second);
+
+  std::pair<StringData *, bool> res2 = DataPool.insert(Allocator, "2");
+  // Check old entry is still valid.
+  EXPECT_TRUE(res1.first->key() == "1");
+  // Check new entry is inserted.
+  EXPECT_TRUE(res2.first->key() == "2");
+  EXPECT_TRUE(res2.second);
+  // Check new and old entries use different memory.
+  EXPECT_TRUE(res1.first != res2.first);
+
+  std::pair<StringData *, bool> res3 = DataPool.insert(Allocator, "3");
+  // Check one more entry is inserted.
+  EXPECT_TRUE(res3.first->key() == "3");
+  EXPECT_TRUE(res3.second);
+
+  std::pair<StringData *, bool> res4 = DataPool.insert(Allocator, "1");
+  // Check duplicated entry is not inserted again.
+  EXPECT_TRUE(res4.first->key() == "1");
+  EXPECT_FALSE(res4.second);
+  // Check duplicated entry matches with the first one.
+  EXPECT_TRUE(res1.first == res4.first);
+
+  // Check first entry is still valid.
+  EXPECT_TRUE(res1.first->key() == "1");
+
+  // Check data was allocated by allocator.
+  EXPECT_TRUE(Allocator.getBytesAllocated() > 0);
+
+  // Check statistic.
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  DataPool.printStatistic(StatisticStream);
+  StatisticStream.flush();
+
+  EXPECT_TRUE(StatisticString.find("LockFreeDataPool:") != std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Number of slabs: 1") != std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Slab depth: 0") != std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Slab fullness: 0.07%") !=
+              std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Overall number of used elements: 3") !=
+              std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Overall fullness of pool: 0.07%") !=
+              std::string::npos);
+}
+
+TEST(LockFreeDataPoolTest, AddEntriesParallel) {
+  BumpPtrAllocator Allocator;
+  LockFreeDataPool<StringRef, StringData, BumpPtrAllocator> DataPool;
+
+  DataPool.init(Allocator, 5000);
+
+  // Number of elements exceeds original size, thus pool`s hashtable
+  // should be resized.
+  const size_t NumElements = 10000;
+  // One allocator per element, since an allocator must not be used by
+  // several threads at once. They live on the heap: 10000 allocator
+  // objects are too large for the stack on some platforms.
+  std::vector<BumpPtrAllocator> Allocators(NumElements);
+
+  // Check parallel insertion.
+  parallelForEachN(0, NumElements, [&](size_t I) {
+    std::string StringForElement = formatv("{0}", I).str();
+    std::pair<StringData *, bool> Entry =
+        DataPool.insert(Allocators[I], StringForElement);
+    EXPECT_TRUE(Entry.second);
+    EXPECT_TRUE(Entry.first->key() == StringForElement);
+    EXPECT_TRUE(Allocators[I].getBytesAllocated() > 0);
+  });
+
+  // Check parallel insertion of duplicates.
+  parallelForEachN(0, NumElements, [&](size_t I) {
+    size_t BytesAllocated = Allocators[I].getBytesAllocated();
+    std::string StringForElement = formatv("{0}", I).str();
+    std::pair<StringData *, bool> Entry =
+        DataPool.insert(Allocators[I], StringForElement);
+    EXPECT_FALSE(Entry.second);
+    EXPECT_TRUE(Entry.first->key() == StringForElement);
+    // Check no additional bytes were allocated for duplicate.
+    EXPECT_TRUE(Allocators[I].getBytesAllocated() == BytesAllocated);
+  });
+
+  // Check statistic.
+  std::string StatisticString;
+  raw_string_ostream StatisticStream(StatisticString);
+  DataPool.printStatistic(StatisticStream);
+  StatisticStream.flush();
+
+  EXPECT_TRUE(StatisticString.find("Number of slabs: 2") != std::string::npos);
+  EXPECT_TRUE(StatisticString.find("Overall number of used elements: 10000") !=
+              std::string::npos);
+}
+
+} // end anonymous namespace