diff --git a/llvm/include/llvm/ADT/FixedConcurrentHashTable.h b/llvm/include/llvm/ADT/FixedConcurrentHashTable.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/ADT/FixedConcurrentHashTable.h @@ -0,0 +1,323 @@ +//===- FixedConcurrentHashTable.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_FIXEDCONCURRENTHASHTABLE_H +#define LLVM_ADT_FIXEDCONCURRENTHASHTABLE_H + +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/NativeFormatting.h" +#include "llvm/Support/WithColor.h" +#include + +namespace llvm { + +/// FixedConcurrentHashTable - a non-resizable concurrent lock-free +/// hashtable. +/// +/// It is based on this paper: +/// Concurrent Hash Tables: Fast and General(?)! +/// https://dl.acm.org/doi/10.1145/3309206 +/// +/// This hash table is meant to be used in two phases: +/// 1. concurrent insertions +/// 2. concurrent reads +/// It does not support lookup, deletion, or rehashing. It uses linear probing. +/// +/// The paper describes storing a key-value pair in two machine words. +/// +/// FixedConcurrentHashTable: keeps a KeyDataTy (which is a key-value pair). +/// The state when all data bits of KeyDataTy are zero is reserved as +/// the marker for an empty hashtable slot. +/// +/// FixedConcurrentHashTableExternalAlloc: keeps a pointer to the KeyDataTy +/// class (which is a key-value pair), whose data are allocated and kept by the +/// external allocator. 
+ +template +class FixedConcurrentHashTableBase { +public: + using MapEntryTy = typename std::conditional::type; + using InsertionTy = + typename std::conditional::type; + + static_assert((sizeof(MapEntryTy) == 1 || sizeof(MapEntryTy) == 2 || + sizeof(MapEntryTy) == 4 || sizeof(MapEntryTy) == 8), + "sizeof(MapEntryTy) should be of 1,2,4 or 8"); + static_assert((sizeof(MapEntryTy) == std::alignment_of::value), + "MapEntryTy should be naturally aligned"); + + FixedConcurrentHashTableBase() = delete; + FixedConcurrentHashTableBase(size_t InitialSize, AllocatorTy &Allocator) + : Allocator(Allocator) { + // Size the table to a power of two and zero it: all-zero means empty. + TableSize = NextPowerOf2(InitialSize); + HashTable = Allocator.template Allocate(TableSize); + memset(HashTable, 0, TableSize * sizeof(MapEntryTy)); + } + ~FixedConcurrentHashTableBase() { + Allocator.template Deallocate(HashTable, TableSize); + } + + /// Print information about current state of hash table structures. + void printStatistic(raw_ostream &OS) { + OS << "\n--- HashTable statistic:\n"; + OS << "\n HashTable size: " << TableSize; + size_t NumberOfEntries = getNumberOfEntriesForTable(); + OS << "\n Number of entries: " << NumberOfEntries; + OS << "\n Load factor: " + << formatv("{0}", static_cast(NumberOfEntries) / + static_cast(TableSize)); + + OS << "\n Overall size(in bytes): "; + if constexpr (KeepDataInsideTable) + OS << ((TableSize * sizeof(MapEntryTy)) + sizeof(*this)); + else + OS << (TableSize * sizeof(MapEntryTy) + sizeof(*this) + + NumberOfEntries * sizeof(KeyDataTy)); + + OS << "\n\n"; + } + +protected: + // Returns the number of occupied (non-null) slots by scanning the table. + size_t getNumberOfEntriesForTable() { + size_t NumberOfEntries = 0; + for (size_t CurIdx = 0; CurIdx < TableSize; CurIdx++) { + if (!isNull(HashTable[CurIdx])) + NumberOfEntries++; + } + + return NumberOfEntries; + } + + template + std::pair insertImpl(InsertionTy NewData, + Types &... Alloc) { + // Do a linear probe starting at StartIdx. 
+ size_t StartIdx = getHashCode(NewData) & getHashMask(); + size_t Idx = StartIdx; + MapEntryTy DataToInsert; + setToNull(&DataToInsert); + + while (true) { + // Run a compare and swap loop. There are three cases: + // - cell is empty: CAS into place and return. + // - cell has matching key, return cell's data. + // - cell has non-matching key: hash collision, probe next cell. + std::atomic &EntryRef = + reinterpret_cast &>(HashTable[Idx]); + MapEntryTy Candidate = EntryRef.load(std::memory_order_acquire); + if (isNull(Candidate)) { + // Create new entry. + if constexpr (KeepDataInsideTable) + DataToInsert = NewData; + else if (isNull(DataToInsert)) + DataToInsert = Info::create(NewData, Alloc...); + + // Try to put new entry into the table. + if (EntryRef.compare_exchange_weak(Candidate, DataToInsert, + std::memory_order_release)) + return std::make_pair(DataToInsert, true); + continue; + } + + if (isEqual(Candidate, NewData)) { + // The key is already present in the table. + if constexpr (KeepDataInsideTable) { + if constexpr (is_detected::value) { + // Check order and overwrite entry if necessary. + if (Info::lessThan(NewData, Candidate)) { + if (EntryRef.compare_exchange_weak(Candidate, NewData, + std::memory_order_release)) + return std::make_pair(NewData, true); + + continue; + } + } + } + + // Return already existing entry. + return std::make_pair(Candidate, false); + } + + // Advance the probe. Wrap around to the beginning if we run off the end. + ++Idx; + + Idx = Idx == this->TableSize ? 0 : Idx; + if (Idx == StartIdx) { + // If this becomes an issue, we could mark failure and rehash from the + // beginning with a bigger table. There is no difference between + // rehashing internally and starting over. 
+ report_fatal_error("hash table is full"); + } + } + llvm_unreachable("left infloop"); + } + + template ::value), + bool>::type = true> + bool isNull(MapEntryTy Data) { + return Data == nullptr; + } + + template ::value && + sizeof(MapEntryTy) == sizeof(uint8_t)), + bool>::type = true> + bool isNull(MapEntryTy Data) { + return *reinterpret_cast(&Data) == static_cast(0); + } + + template ::value && + sizeof(MapEntryTy) == sizeof(uint16_t)), + bool>::type = true> + bool isNull(MapEntryTy Data) { + return *reinterpret_cast(&Data) == static_cast(0); + } + + template ::value && + sizeof(MapEntryTy) == sizeof(uint32_t)), + bool>::type = true> + bool isNull(MapEntryTy Data) { + return *reinterpret_cast(&Data) == static_cast(0); + } + + template ::value && + sizeof(MapEntryTy) == sizeof(uint64_t)), + bool>::type = true> + bool isNull(MapEntryTy Data) { + return *reinterpret_cast(&Data) == static_cast(0); + } + + static inline void setToNull(MapEntryTy *Data) { + memset(Data, 0, sizeof(MapEntryTy)); + } + + bool isEqual(MapEntryTy LHS, const KeyTy &RHS) { + return Info::isEqual(Info::getKey(*LHS), RHS); + } + + bool isEqual(MapEntryTy LHS, const KeyDataTy &RHS) { + return Info::isEqual(Info::getKey(LHS), Info::getKey(RHS)); + } + + hash_code getHashCode(const KeyTy &Data) { return Info::getHashValue(Data); } + + hash_code getHashCode(const KeyDataTy &Data) { + return Info::getHashValue(Info::getKey(Data)); + } + + template + using haslessThan = decltype(std::declval().lessThan( + std::declval(), std::declval())); + + size_t getHashMask() { return TableSize - 1; } + + // Pointer to the Hashtable. + MapEntryTy *HashTable = nullptr; + + // Size of Hashtable. + size_t TableSize = 0; + + // This allocator is used to allocate and deallocate hashtable. 
+ AllocatorTy &Allocator; +}; + +/// FixedConcurrentHashTable: +/// + +template class FixedConcurrentHashTableInfo { +public: + static inline hash_code getHashValue(KeyTy Key) { + return std::hash()(Key); + } + static inline bool isEqual(KeyTy LHS, KeyTy RHS) { return LHS == RHS; } + static inline KeyTy getKey(const KeyDataTy &Data) { return Data.key(); } + + // Optional: + // static inline bool lessThan(const KeyDataTy& LHS, const KeyDataTy& RHS) { + // return LHS < RHS; } +}; + +template > +class FixedConcurrentHashTable + : public FixedConcurrentHashTableBase { +public: + FixedConcurrentHashTable(size_t InitialSize, AllocatorTy &Allocator) + : FixedConcurrentHashTableBase( + InitialSize, Allocator) {} + + /// Insert new entry \p NewData or return already existing entry. + /// + /// \returns entry and "true" if an entry is just inserted or + /// "false" if an entry already exists. + std::pair insert(KeyDataTy NewData) { + return FixedConcurrentHashTableBase::insertImpl(NewData); + } +}; + +/// FixedConcurrentHashTableExternalAlloc: +/// +/// Insertions into the internal hash table are lock-free, and +/// allocations are done using an external non-mt-safe data pool. + +template +class FixedConcurrentHashTableInfoExternalAlloc { +public: + static inline hash_code getHashValue(const KeyTy &Key) { + return std::hash()(Key); + } + static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) { + return LHS == RHS; + } + static inline KeyTy getKey(const KeyDataTy &Data) { return Data.key(); } + static inline KeyDataTy *create(const KeyTy &Key, AllocatorTy &Allocator) { + return KeyDataTy::create(Key, Allocator); + } +}; + +template > +class FixedConcurrentHashTableExternalAlloc + : public FixedConcurrentHashTableBase { +public: + FixedConcurrentHashTableExternalAlloc(size_t InitialSize, + AllocatorTy &Allocator) + : FixedConcurrentHashTableBase(InitialSize, Allocator) {} + + /// Insert new entry for \p Key or return already existing entry. 
+ /// + /// \returns entry and "true" if an entry is just inserted or + /// "false" if an entry already exists. + std::pair insert(const KeyTy &Key, + AllocatorTy &ThreadLocalAllocator) { + return FixedConcurrentHashTableBase::insertImpl(Key, + ThreadLocalAllocator); + } +}; + +} // end namespace llvm + +#endif // LLVM_ADT_FIXEDCONCURRENTHASHTABLE_H diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -26,6 +26,7 @@ EnumeratedArrayTest.cpp EquivalenceClassesTest.cpp FallibleIteratorTest.cpp + FixedConcurrentHashtableTest.cpp FloatingPointMode.cpp FoldingSet.cpp FunctionExtrasTest.cpp diff --git a/llvm/unittests/ADT/FixedConcurrentHashtableTest.cpp b/llvm/unittests/ADT/FixedConcurrentHashtableTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ADT/FixedConcurrentHashtableTest.cpp @@ -0,0 +1,278 @@ +//===- FixedConcurrentHashtableTest.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/FixedConcurrentHashTable.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Parallel.h" +#include "gtest/gtest.h" +#include +#include +using namespace llvm; + +namespace { + +class Int { +public: + Int() : Data(0x0) {} + uint32_t key() const { return Data & 0x7FFFFFFF; } + + static Int create(uint32_t Data) { return Int(Data | 0x80000000); } + + friend bool operator==(const Int &LHS, const Int &RHS) { + return LHS.Data == RHS.Data; + } + +protected: + Int(uint32_t Data) : Data(Data) {} + + uint32_t Data; +}; + +class alignas(4) OrderedInt { +public: + OrderedInt() : Data(0x0), Order(0) {} + + uint16_t key() const { return Data & 0x7FFF; } + + friend bool operator==(const OrderedInt &LHS, const OrderedInt &RHS) { + return std::tie(LHS.Data, LHS.Order) == std::tie(RHS.Data, RHS.Order); + } + + static OrderedInt create(uint16_t Data, uint16_t Order) { + return OrderedInt(Data | 0x8000, Order); + } + + bool lessThan(const OrderedInt &Other) const { return Order < Other.Order; } + +protected: + OrderedInt(uint16_t Data, uint16_t Order) + : Data(Data | 0x8000), Order(Order) {} + + uint16_t Data; + uint16_t Order; +}; + +class String { +public: + static String *create(const std::string &Key, BumpPtrAllocator &Allocator) { + return new (Allocator) String(Key); + } + + const std::string &key() const { return Data; } + +protected: + String(const std::string &Data) : Data(Data) {} + std::string Data; +}; + +TEST(FixedConcurrentHashTableTest, AddIntEntries) { + BumpPtrAllocator Allocator; + FixedConcurrentHashTable HashTable(10, Allocator); + + std::pair res1 = HashTable.insert(Int::create(1)); + // Check entry is inserted. 
+ EXPECT_TRUE(res1.first.key() == 1); + EXPECT_TRUE(res1.second); + + std::pair res2 = HashTable.insert(Int::create(2)); + // Check old entry is still valid. + EXPECT_TRUE(res1.first.key() == 1); + // Check new entry is inserted. + EXPECT_TRUE(res2.first.key() == 2); + EXPECT_TRUE(res2.second); + // Check new and old entries are not equal. + EXPECT_FALSE(res1.first == res2.first); + + std::pair res3 = HashTable.insert(Int::create(3)); + // Check one more entry is inserted. + EXPECT_TRUE(res3.first.key() == 3); + EXPECT_TRUE(res3.second); + + std::pair res4 = HashTable.insert(Int::create(1)); + // Check duplicated entry is not inserted again. + EXPECT_TRUE(res4.first.key() == 1); + EXPECT_FALSE(res4.second); + // Check duplicated entry matches with the first one. + EXPECT_TRUE(res1.first == res4.first); + + // Check first entry is still valid. + EXPECT_TRUE(res1.first.key() == 1); + + // Check statistic. + std::string StatisticString; + raw_string_ostream StatisticStream(StatisticString); + HashTable.printStatistic(StatisticStream); + + EXPECT_TRUE(StatisticString.find("HashTable statistic:") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("HashTable size: 16") != std::string::npos); + EXPECT_TRUE(StatisticString.find("Load factor: 0.19") != std::string::npos); +} + +TEST(FixedConcurrentHashTableTest, AddOrderedIntEntries) { + class HashTableInfo { + public: + static inline hash_code getHashValue(uint32_t Key) { + return llvm::hash_value(Key); + } + static inline bool isEqual(uint32_t LHS, uint32_t RHS) { + return LHS == RHS; + } + static inline uint32_t getKey(const OrderedInt &Data) { return Data.key(); } + static inline bool lessThan(const OrderedInt &LHS, const OrderedInt &RHS) { + return LHS.lessThan(RHS); + } + }; + + BumpPtrAllocator Allocator; + FixedConcurrentHashTable + HashTable(10, Allocator); + + std::pair res = HashTable.insert(OrderedInt::create(1, 2)); + // Check entry is inserted. 
+ EXPECT_TRUE(res.first == OrderedInt::create(1, 2)); + EXPECT_TRUE(res.second); + + res = HashTable.insert(OrderedInt::create(1, 3)); + // Check entry is unchanged. + EXPECT_TRUE(res.first == OrderedInt::create(1, 2)); + EXPECT_FALSE(res.second); + + res = HashTable.insert(OrderedInt::create(1, 2)); + // Check entry is unchanged. + EXPECT_TRUE(res.first == OrderedInt::create(1, 2)); + EXPECT_FALSE(res.second); + + res = HashTable.insert(OrderedInt::create(1, 1)); + // Check entry is overwritten. + EXPECT_TRUE(res.first == OrderedInt::create(1, 1)); + EXPECT_TRUE(res.second); +} + +TEST(FixedConcurrentHashTableTest, AddIntEntriesParallel) { + BumpPtrAllocator Allocator; + size_t NumElements = 10000; + FixedConcurrentHashTable HashTable(NumElements, Allocator); + + // Check parallel insertion. + parallelFor(0, NumElements, [&](size_t I) { + std::pair Entry = HashTable.insert(Int::create(I)); + EXPECT_TRUE(Entry.second); + EXPECT_TRUE(Entry.first.key() == I); + }); + + // Check parallel insertion of duplicates. + parallelFor(0, NumElements, [&](size_t I) { + std::pair Entry = HashTable.insert(Int::create(I)); + EXPECT_FALSE(Entry.second); + EXPECT_TRUE(Entry.first.key() == I); + }); + + // Check statistic. + std::string StatisticString; + raw_string_ostream StatisticStream(StatisticString); + HashTable.printStatistic(StatisticStream); + + EXPECT_TRUE(StatisticString.find("HashTable statistic:") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("HashTable size: 16384") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("Load factor: 0.61") != std::string::npos); +} + +TEST(FixedConcurrentHashTableTest, AddStringEntries) { + BumpPtrAllocator Allocator; + FixedConcurrentHashTableExternalAlloc + HashTable(10, Allocator); + + std::pair res1 = HashTable.insert("1", Allocator); + // Check entry is inserted. 
+ EXPECT_TRUE(res1.first->key() == "1"); + EXPECT_TRUE(res1.second); + + std::pair res2 = HashTable.insert("2", Allocator); + // Check old entry is still valid. + EXPECT_TRUE(res1.first->key() == "1"); + // Check new entry is inserted. + EXPECT_TRUE(res2.first->key() == "2"); + EXPECT_TRUE(res2.second); + // Check new and old entries use different memory. + EXPECT_TRUE(res1.first != res2.first); + + std::pair res3 = HashTable.insert("3", Allocator); + // Check one more entry is inserted. + EXPECT_TRUE(res3.first->key() == "3"); + EXPECT_TRUE(res3.second); + + std::pair res4 = HashTable.insert("1", Allocator); + // Check duplicated entry is not inserted again. + EXPECT_TRUE(res4.first->key() == "1"); + EXPECT_FALSE(res4.second); + // Check duplicated entry matches with the first one. + EXPECT_TRUE(res1.first == res4.first); + + // Check first entry is still valid. + EXPECT_TRUE(res1.first->key() == "1"); + + // Check data was allocated by allocator. + EXPECT_TRUE(Allocator.getBytesAllocated() > 0); + + // Check statistic. + std::string StatisticString; + raw_string_ostream StatisticStream(StatisticString); + HashTable.printStatistic(StatisticStream); + + EXPECT_TRUE(StatisticString.find("HashTable statistic:") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("HashTable size: 16") != std::string::npos); + EXPECT_TRUE(StatisticString.find("Load factor: 0.19") != std::string::npos); +} + +TEST(FixedConcurrentHashTableTest, AddStringEntriesParallel) { + const size_t NumElements = 10000; + static LLVM_THREAD_LOCAL BumpPtrAllocator Allocator; + FixedConcurrentHashTableExternalAlloc + HashTable(NumElements, Allocator); + + // Check parallel insertion. 
+ parallelFor(0, NumElements, [&](size_t I) { + std::string StringForElement = formatv("{0}", I); + std::pair Entry = + HashTable.insert(StringForElement, Allocator); + EXPECT_TRUE(Entry.second); + EXPECT_TRUE(Entry.first->key() == StringForElement); + EXPECT_TRUE(Allocator.getBytesAllocated() > 0); + }); + + // Check parallel insertion of duplicates. + parallelFor(0, NumElements, [&](size_t I) { + size_t BytesAllocated = Allocator.getBytesAllocated(); + std::string StringForElement = formatv("{0}", I); + std::pair Entry = + HashTable.insert(StringForElement, Allocator); + EXPECT_FALSE(Entry.second); + EXPECT_TRUE(Entry.first->key() == StringForElement); + // Check no additional bytes were allocated for duplicate. + EXPECT_TRUE(Allocator.getBytesAllocated() == BytesAllocated); + }); + + // Check statistic. + std::string StatisticString; + raw_string_ostream StatisticStream(StatisticString); + HashTable.printStatistic(StatisticStream); + + EXPECT_TRUE(StatisticString.find("HashTable statistic:") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("HashTable size: 16384") != + std::string::npos); + EXPECT_TRUE(StatisticString.find("Load factor: 0.61") != std::string::npos); +} + +} // end anonymous namespace