diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -0,0 +1,146 @@
+//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a definition of the ThreadLocalCache class. This class
+// provides support for defining thread local objects with non-static duration.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
+#define MLIR_SUPPORT_THREADLOCALCACHE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/ThreadLocal.h"
+#include <memory>
+
+namespace mlir {
+/// This class provides support for defining a thread local object with non
+/// static storage duration. This is very useful for situations in which a data
+/// cache has very large lock contention.
+///
+/// `get()` may be called concurrently from any number of threads. The cache may
+/// be destroyed while threads that used it are still running or exiting, but
+/// not while another thread is inside `get()` or using a value it returned.
+template <typename ValueT>
+class ThreadLocalCache {
+  /// The state owned by a ThreadLocalCache. It is shared, through aliasing
+  /// `std::shared_ptr`s, with the thread local maps that reference it, so a
+  /// thread that successfully locks one of those weak references keeps this
+  /// state (and therefore its mutex and value list) alive.
+  struct PerInstanceState {
+    /// Remove the given value entry. This is called when a thread local cache
+    /// is destructing.
+    void remove(ValueT *value) {
+      // Erase the found value directly, because it is guaranteed to be in the
+      // list.
+      llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+      auto it = llvm::find_if(instances,
+                              [&](const std::unique_ptr<ValueT> &instance) {
+                                return instance.get() == value;
+                              });
+      assert(it != instances.end() && "expected value to exist in cache");
+      instances.erase(it);
+    }
+
+    /// Owning pointers to all of the values that have been constructed for
+    /// this object in the static cache.
+    SmallVector<std::unique_ptr<ValueT>, 1> instances;
+
+    /// A mutex used when a new thread instance has been added to the cache
+    /// for this object.
+    llvm::sys::SmartMutex<true> instanceMutex;
+  };
+
+  /// The type used for the static thread_local cache. This is a map between an
+  /// instance of the non-static cache and a weak reference to an instance of
+  /// ValueT. We use a weak reference here so that the object can be destroyed
+  /// without needing to lock access to the cache itself.
+  struct CacheType
+      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
+    ~CacheType() {
+      // Remove the values of this cache that haven't already expired. Locking
+      // the weak reference, instead of testing `expired()` first, leaves no
+      // window in which the value can expire between the check and its use,
+      // and keeps the owning PerInstanceState alive during `remove`.
+      for (auto &it : *this)
+        if (std::shared_ptr<ValueT> value = it.second.lock())
+          it.first->remove(value.get());
+    }
+
+    /// Erase the entries whose owning cache has been destroyed. This is not
+    /// thread-safe, and must only be called by the thread that owns this map.
+    void clearExpiredEntries() {
+      for (auto it = this->begin(), e = this->end(); it != e;) {
+        auto curIt = it++;
+        if (curIt->second.expired())
+          this->erase(curIt);
+      }
+    }
+  };
+
+public:
+  ThreadLocalCache() = default;
+  ~ThreadLocalCache() {
+    // No cleanup is necessary here: releasing `perInstanceState` destroys the
+    // values (once no exiting thread is still removing one of them) and
+    // expires the weak pointers held by the thread_local caches.
+  }
+
+  /// Return an instance of the value type for the current thread.
+  ValueT &get() {
+    // Check for an already existing instance for this thread.
+    CacheType &staticCache = getStaticCache();
+    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
+    if (std::shared_ptr<ValueT> value = threadInstance.lock())
+      return *value;
+
+    // Otherwise, create a new instance for this thread.
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(
+        perInstanceState->instanceMutex);
+    perInstanceState->instances.push_back(std::make_unique<ValueT>());
+    ValueT *instance = perInstanceState->instances.back().get();
+
+    // Alias the per-instance state so that locking this weak reference also
+    // keeps that state alive.
+    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
+
+    // Before returning the new instance, prune the entries left behind by
+    // caches that have since been destroyed so the thread local map does not
+    // grow without bound. Only the owning thread touches the map, so this
+    // needs no lock.
+    staticCache.clearExpiredEntries();
+    return *instance;
+  }
+  ValueT &operator*() { return get(); }
+  ValueT *operator->() { return &get(); }
+
+private:
+  ThreadLocalCache(ThreadLocalCache &&) = delete;
+  ThreadLocalCache(const ThreadLocalCache &) = delete;
+  ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
+
+  /// Return the static thread local instance of the cache type. This uses
+  /// C++11 `thread_local` rather than LLVM_THREAD_LOCAL, which can expand to
+  /// `__thread` (restricted to constant-initialized PODs), whereas CacheType
+  /// needs dynamic initialization and a destructor that runs at thread exit.
+  static CacheType &getStaticCache() {
+    static thread_local CacheType cache;
+    return cache;
+  }
+
+  /// The state shared between this cache and the thread local maps that
+  /// reference it.
+  std::shared_ptr<PerInstanceState> perInstanceState =
+      std::make_shared<PerInstanceState>();
+};
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_THREADLOCALCACHE_H
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -278,8 +279,12 @@
   /// operations.
   llvm::StringMap<AbstractOperation> registeredOperations;
 
-  /// These are identifiers uniqued into this MLIRContext.
+  /// Identifiers are uniqued by string value and use the internal string set
+  /// for storage.
   llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
+  /// A thread local cache of identifiers to reduce lock contention.
+  ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
+      localIdentifierCache;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -619,27 +624,34 @@
 
 /// Return an identifier for the specified string.
 Identifier Identifier::get(StringRef str, MLIRContext *context) {
+  // Check invariants up front, before any of the lookup paths below can
+  // return an identifier for `str`.
+  assert(!str.empty() && "Cannot create an empty identifier");
+  assert(str.find('\0') == StringRef::npos &&
+         "Cannot create an identifier with a nul character");
+
   auto &impl = context->getImpl();
+  if (!context->isMultithreadingEnabled())
+    return Identifier(&*impl.identifiers.insert(str).first);
+
+  // Check for an existing instance in the local cache.
+  auto *&localEntry = (*impl.localIdentifierCache)[str];
+  if (localEntry)
+    return Identifier(localEntry);
 
-  // Check for an existing identifier in read-only mode.
-  if (context->isMultithreadingEnabled()) {
+  // Check for an existing identifier in read-only mode. Multithreading is
+  // known to be enabled at this point, so the reader lock is unconditional.
+  {
     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
     auto it = impl.identifiers.find(str);
     if (it != impl.identifiers.end())
-      return Identifier(&*it);
+      return Identifier(localEntry = &*it);
   }
 
-  // Check invariants after seeing if we already have something in the
-  // identifier table - if we already had it in the table, then it already
-  // passed invariant checks.
-  assert(!str.empty() && "Cannot create an empty identifier");
-  assert(str.find('\0') == StringRef::npos &&
-         "Cannot create an identifier with a nul character");
-
   // Acquire a writer-lock so that we can safely create the new instance.
-  ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled);
+  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
   auto it = impl.identifiers.insert(str).first;
-  return Identifier(&*it);
+  return Identifier(localEntry = &*it);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Support/StorageUniquer.h"
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/Support/RWMutex.h"
 
@@ -35,6 +36,8 @@
   /// A utility wrapper object representing a hashed storage object. This class
   /// contains a storage object and an existing computed hash value.
   struct HashedStorage {
+    HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
+        : hashValue(hashValue), storage(storage) {}
     unsigned hashValue;
     BaseStorage *storage;
   };
@@ -42,10 +45,10 @@
   /// Storage info for derived TypeStorage objects.
   struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
     static HashedStorage getEmptyKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
     }
     static HashedStorage getTombstoneKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
     }
 
     static unsigned getHashValue(const HashedStorage &key) {
@@ -98,24 +101,40 @@
     if (!threadingIsEnabled)
       return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
 
+    // Check for an instance of this object in the local cache. The local cache
+    // is only updated once the instance is known, and no iterator or reference
+    // into it is held across the code below: `ctorFn` may re-enter the uniquer
+    // on this thread and insert into the same cache, invalidating them.
+    auto &localCache = *complexStorageLocalCache;
+    auto localIt = localCache.find_as(lookupKey);
+    if (localIt != localCache.end())
+      return localIt->storage;
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
       auto it = storageUniquer.complexInstances.find_as(lookupKey);
-      if (it != storageUniquer.complexInstances.end())
-        return it->storage;
+      if (it != storageUniquer.complexInstances.end()) {
+        localCache.insert(*it);
+        return it->storage;
+      }
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+    BaseStorage *storage =
+        getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+    localCache.insert(
+        InstSpecificUniquer::HashedStorage(lookupKey.hashValue, storage));
+    return storage;
   }
   /// Get or create an instance of a complex derived type in an unsafe fashion.
   BaseStorage *
   getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
-                    InstSpecificUniquer::LookupKey &lookupKey,
+                    InstSpecificUniquer::LookupKey &key,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
+    auto existing =
+        storageUniquer.complexInstances.insert_as({key.hashValue}, key);
     if (!existing.second)
       return existing.first->storage;
 
@@ -123,9 +142,7 @@
     // instance.
     BaseStorage *storage =
         initializeStorage(kind, storageUniquer.allocator, ctorFn);
-    *existing.first =
-        InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
-    return storage;
+    return existing.first->storage = storage;
   }
 
   /// Get or create an instance of a simple derived type.
@@ -137,6 +154,14 @@
     if (!threadingIsEnabled)
       return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
 
+    // Check for an instance of this object in the local cache. No reference
+    // into the cache is held across the code below, because `ctorFn` may
+    // re-enter the uniquer on this thread and insert into the same cache.
+    auto &localCache = *simpleStorageLocalCache;
+    auto localIt = localCache.find(kind);
+    if (localIt != localCache.end())
+      return localIt->second;
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
@@ -147,7 +172,9 @@
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
+    BaseStorage *storage = getOrCreateUnsafe(storageUniquer, kind, ctorFn);
+    localCache.insert({kind, storage});
+    return storage;
   }
   /// Get or create an instance of a simple derived type in an unsafe fashion.
   BaseStorage *
@@ -209,6 +236,12 @@
   /// Map of type ids to the storage uniquer to use for registered objects.
   DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
 
+  /// A thread local cache for simple and complex storage objects. This helps to
+  /// reduce the lock contention when an object already exists in the cache.
+  ThreadLocalCache<DenseMap<unsigned, BaseStorage *>> simpleStorageLocalCache;
+  ThreadLocalCache<InstSpecificUniquer::StorageTypeSet>
+      complexStorageLocalCache;
+
   /// Flag specifying if multi-threading is enabled within the uniquer.
   bool threadingIsEnabled = true;
 };