diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Support/ThreadLocalCache.h @@ -0,0 +1,120 @@ +//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a definition of the ThreadLocalCache class. This class +// provides support for defining thread local objects with non-static duration. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H +#define MLIR_SUPPORT_THREADLOCALCACHE_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/Mutex.h" +#include "llvm/Support/ThreadLocal.h" + +namespace mlir { +/// This class provides support for defining a thread local object with non +/// static storage duration. This is very useful for situations in which a data +/// cache has very large lock contention. +template +class ThreadLocalCache { + /// The type used for the static thread_local cache. This is a map between an + /// instance of the non-static cache and a weak reference to an instance of + /// ValueT. We use a weak reference here so that the object can be destroyed + /// without needing to lock access to the cache itself. + struct CacheType : public llvm::SmallDenseMap *, + std::weak_ptr> { + ~CacheType() { + // Remove the values of this cache that haven't already expired. + for (auto &it : *this) + if (std::shared_ptr value = it.second.lock()) + it.first->remove(value.get()); + } + + /// Clear out any unused entries within the map.
This method is not + /// thread-safe, and should only be called by the same thread as the cache. + void clearExpiredEntries() { + for (auto it = this->begin(), e = this->end(); it != e;) { + auto curIt = it++; + if (curIt->second.expired()) + this->erase(curIt); + } + } + }; + +public: + ThreadLocalCache() = default; + ~ThreadLocalCache() { + // No cleanup is necessary here as the shared_ptr memory will go out of + // scope and invalidate the weak pointers held by the thread_local caches. + // NOTE(review): ~CacheType on a concurrently exiting thread may lock one of + // those weak pointers just before this destructor runs and then call + // remove() on the destroyed object; confirm that this cannot happen. + } + + /// Return an instance of the value type for the current thread. + ValueT &get() { + // Check for an already existing instance for this thread. + CacheType &staticCache = getStaticCache(); + std::weak_ptr &threadInstance = staticCache[this]; + if (std::shared_ptr value = threadInstance.lock()) + return *value; + + // Otherwise, create a new instance for this thread. + llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); + instances.push_back(std::make_shared()); + std::shared_ptr &instance = instances.back(); + threadInstance = instance; + + // Before returning the new instance, take the chance to clear out any + // unused entries in the static map. The cache is only cleared within the + // same thread to remove the need to lock the cache itself. + staticCache.clearExpiredEntries(); + return *instance; + } + ValueT &operator*() { return get(); } + ValueT *operator->() { return &get(); } + +private: + ThreadLocalCache(ThreadLocalCache &&) = delete; + ThreadLocalCache(const ThreadLocalCache &) = delete; + ThreadLocalCache &operator=(const ThreadLocalCache &) = delete; + + /// Return the static thread local instance of the cache type. + static CacheType &getStaticCache() { + static LLVM_THREAD_LOCAL CacheType cache; + return cache; + } + + /// Remove the given value entry. This is generally called when a thread local + /// cache is destructing. 
+ void remove(ValueT *value) { + // Erase the found value directly, because it is guaranteed to be in the + // list.
+ llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); + auto it = llvm::find_if(instances, [&](std::shared_ptr &instance) { + return instance.get() == value; + }); + assert(it != instances.end() && "expected value to exist in cache"); + instances.erase(it); + } + + /// Owning pointers to all of the values that have been constructed for this + /// object in the static cache. + SmallVector, 1> instances; + + /// A mutex used when a new thread instance has been added to the cache for + /// this object. + llvm::sys::SmartMutex instanceMutex; +}; +} // end namespace mlir + +#endif // MLIR_SUPPORT_THREADLOCALCACHE_H diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" +#include "mlir/Support/ThreadLocalCache.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -278,8 +279,12 @@ /// operations. llvm::StringMap registeredOperations; - /// These are identifiers uniqued into this MLIRContext. + /// Identifiers are uniqued by string value and use the internal string set + /// for storage. llvm::StringSet identifiers; + /// A thread local cache of identifiers to reduce lock contention. + ThreadLocalCache *>> + localIdentifierCache; /// An allocator used for AbstractAttribute and AbstractType objects. llvm::BumpPtrAllocator abstractDialectSymbolAllocator; @@ -619,27 +624,37 @@ /// Return an identifier for the specified string. Identifier Identifier::get(StringRef str, MLIRContext *context) { + // Check invariants up front, before consulting the identifier table or the + // thread local cache, so that malformed strings are diagnosed on every path + // (including cache hits) in assert-enabled builds.
+ assert(!str.empty() && "Cannot create an empty identifier"); + assert(str.find('\0') == StringRef::npos && + "Cannot create an identifier with a nul character"); + auto &impl = context->getImpl(); + if (!context->isMultithreadingEnabled()) + return Identifier(&*impl.identifiers.insert(str).first); + + // Check for an existing instance in the local cache. + auto *&localEntry = (*impl.localIdentifierCache)[str]; + if (localEntry) + return Identifier(localEntry); // Check for an existing identifier in read-only mode. if (context->isMultithreadingEnabled()) { llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); auto it = impl.identifiers.find(str); - if (it != impl.identifiers.end()) - return Identifier(&*it); + if (it != impl.identifiers.end()) { + localEntry = &*it; + return Identifier(localEntry); + } } - // Check invariants after seeing if we already have something in the - // identifier table - if we already had it in the table, then it already - // passed invariant checks. - assert(!str.empty() && "Cannot create an empty identifier"); - assert(str.find('\0') == StringRef::npos && - "Cannot create an identifier with a nul character"); - // Acquire a writer-lock so that we can safely create the new instance.
- ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled); + llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); auto it = impl.identifiers.insert(str).first; - return Identifier(&*it); + localEntry = &*it; + return Identifier(localEntry); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -9,6 +9,7 @@ #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/ThreadLocalCache.h" #include "mlir/Support/TypeID.h" #include "llvm/Support/RWMutex.h" @@ -39,6 +40,8 @@ /// A utility wrapper object representing a hashed storage object. This class /// contains a storage object and an existing computed hash value. struct HashedStorage { + HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr) + : hashValue(hashValue), storage(storage) {} unsigned hashValue; BaseStorage *storage; }; @@ -46,10 +49,10 @@ /// Storage info for derived TypeStorage objects. struct StorageKeyInfo : DenseMapInfo { static HashedStorage getEmptyKey() { - return HashedStorage{0, DenseMapInfo::getEmptyKey()}; + return HashedStorage(0, DenseMapInfo::getEmptyKey()); } static HashedStorage getTombstoneKey() { - return HashedStorage{0, DenseMapInfo::getTombstoneKey()}; + return HashedStorage(0, DenseMapInfo::getTombstoneKey()); } static unsigned getHashValue(const HashedStorage &key) { @@ -102,25 +105,34 @@ if (!threadingIsEnabled) return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); + // Check for an instance of this object in the local cache. + auto localIt = complexStorageLocalCache->insert_as( + InstSpecificUniquer::HashedStorage(lookupKey.hashValue), lookupKey); + BaseStorage *&localInst = localIt.first->storage; + if (localInst) + return localInst; + // Check for an existing instance in read-only mode.
{ llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); auto it = storageUniquer.complexInstances.find_as(lookupKey); if (it != storageUniquer.complexInstances.end()) - return it->storage; + return localInst = it->storage; } // Acquire a writer-lock so that we can safely create the new type instance. llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); - return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); + return localInst = + getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); } /// Get or create an instance of a complex derived type in an thread-unsafe /// fashion. BaseStorage * getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind, - InstSpecificUniquer::LookupKey &lookupKey, + InstSpecificUniquer::LookupKey &key, function_ref ctorFn) { - auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey); + auto existing = + storageUniquer.complexInstances.insert_as({key.hashValue}, key); if (!existing.second) return existing.first->storage; @@ -128,9 +140,7 @@ // instance. BaseStorage *storage = initializeStorage(kind, storageUniquer.allocator, ctorFn); - *existing.first = - InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage}; - return storage; + return existing.first->storage = storage; } /// Get or create an instance of a simple derived type. @@ -142,6 +152,13 @@ if (!threadingIsEnabled) return getOrCreateUnsafe(storageUniquer, kind, ctorFn); + // Check for an instance of this object in the local cache. + BaseStorage *&localInst = (*simpleStorageLocalCache)[kind]; + if (localInst) + return localInst; + + // NOTE(review): the unchanged read-only lookup below returns without + // storing into localInst, so instances found there are never cached. // Check for an existing instance in read-only mode. { llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); @@ -152,7 +169,7 @@ // Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); - return getOrCreateUnsafe(storageUniquer, kind, ctorFn); + return localInst = getOrCreateUnsafe(storageUniquer, kind, ctorFn); } /// Get or create an instance of a simple derived type in an thread-unsafe /// fashion. @@ -215,6 +232,16 @@ /// Map of type ids to the storage uniquer to use for registered objects. DenseMap> instUniquers; + /// A thread local cache for simple and complex storage objects. This helps to + /// reduce the lock contention when an object already exists in the cache. + /// NOTE(review): both caches are shared by every InstSpecificUniquer, so the + /// simple cache assumes 'kind' is unique across them, and a hash collision + /// could run one uniquer's lookupKey.isEqual on another uniquer's storage; + /// confirm, or move the caches into InstSpecificUniquer. + ThreadLocalCache> simpleStorageLocalCache; + ThreadLocalCache + complexStorageLocalCache; + /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; };