diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -25,12 +25,40 @@
 /// cache has very large lock contention.
 template <typename ValueT>
 class ThreadLocalCache {
+  // Keep a separate shared_ptr protected state that can be acquired atomically
+  // instead of using shared_ptr's for each value. This avoids a problem
+  // where the instance shared_ptr is locked() successfully, and then the
+  // ThreadLocalCache gets destroyed before remove() can be called successfully.
+  struct PerInstanceState {
+    /// Remove the given value entry. This is generally called when a thread
+    /// local cache is destructing.
+    void remove(ValueT *value) {
+      // Erase the found value directly, because it is guaranteed to be in the
+      // list.
+      llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+      auto it =
+          llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
+            return instance.get() == value;
+          });
+      assert(it != instances.end() && "expected value to exist in cache");
+      instances.erase(it);
+    }
+
+    /// Owning pointers to all of the values that have been constructed for this
+    /// object in the static cache.
+    SmallVector<std::unique_ptr<ValueT>, 1> instances;
+
+    /// A mutex used when a new thread instance has been added to the cache for
+    /// this object.
+    llvm::sys::SmartMutex<true> instanceMutex;
+  };
+
-  /// The type used for the static thread_local cache. This is a map between an
-  /// instance of the non-static cache and a weak reference to an instance of
-  /// ValueT. We use a weak reference here so that the object can be destroyed
-  /// without needing to lock access to the cache itself.
-  struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
-                                                std::weak_ptr<ValueT>> {
+  /// The type used for the static thread_local cache. This is a map from the
+  /// per-instance state of a non-static cache to a weak reference to an
+  /// instance of ValueT. We use a weak reference here so that the object can
+  /// be destroyed without needing to lock access to the cache itself.
+  struct CacheType
+      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
@@ -60,15 +88,18 @@
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::weak_ptr<ValueT> &threadInstance = staticCache[this];
+    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
     if (std::shared_ptr<ValueT> value = threadInstance.lock())
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-    instances.push_back(std::make_shared<ValueT>());
-    std::shared_ptr<ValueT> &instance = instances.back();
-    threadInstance = instance;
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(
+        perInstanceState->instanceMutex);
+    perInstanceState->instances.push_back(std::make_unique<ValueT>());
+    ValueT *instance = perInstanceState->instances.back().get();
+    // Use the aliasing constructor so that locking this weak reference keeps
+    // perInstanceState (which owns `instance`) alive.
+    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
@@ -90,26 +121,10 @@
     return cache;
   }
 
-  /// Remove the given value entry. This is generally called when a thread local
-  /// cache is destructing.
-  void remove(ValueT *value) {
-    // Erase the found value directly, because it is guaranteed to be in the
-    // list.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-    auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
-      return instance.get() == value;
-    });
-    assert(it != instances.end() && "expected value to exist in cache");
-    instances.erase(it);
-  }
-
-  /// Owning pointers to all of the values that have been constructed for this
-  /// object in the static cache.
-  SmallVector<std::shared_ptr<ValueT>, 1> instances;
-
-  /// A mutex used when a new thread instance has been added to the cache for
-  /// this object.
-  llvm::sys::SmartMutex<true> instanceMutex;
+  /// The values owned by this cache. Thread-local cache entries hold weak
+  /// references that alias this pointer, so a locked entry keeps it alive.
+  std::shared_ptr<PerInstanceState> perInstanceState =
+      std::make_shared<PerInstanceState>();
 };
 
 } // namespace mlir
diff --git a/mlir/test/CAPI/pdl.c b/mlir/test/CAPI/pdl.c
--- a/mlir/test/CAPI/pdl.c
+++ b/mlir/test/CAPI/pdl.c
@@ -333,5 +333,6 @@
   testRangeType(ctx);
   testTypeType(ctx);
   testValueType(ctx);
+  mlirContextDestroy(ctx);
   return EXIT_SUCCESS;
 }
diff --git a/mlir/test/CAPI/transform.c b/mlir/test/CAPI/transform.c
--- a/mlir/test/CAPI/transform.c
+++ b/mlir/test/CAPI/transform.c
@@ -84,5 +84,6 @@
   mlirDialectHandleRegisterDialect(mlirGetDialectHandle__transform__(), ctx);
   testAnyOpType(ctx);
   testOperationType(ctx);
+  mlirContextDestroy(ctx);
   return EXIT_SUCCESS;
 }