diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -82,9 +82,6 @@ /// The set containing the allocated storage instances. StorageTypeSet instances; - /// Allocator to use when constructing derived instances. - StorageAllocator allocator; - #if LLVM_ENABLE_THREADS != 0 /// A mutex to keep uniquing thread-safe. llvm::sys::SmartRWMutex mutex; @@ -93,13 +90,12 @@ /// Get or create an instance of a param derived type in an thread-unsafe /// fashion. - BaseStorage * - getOrCreateUnsafe(Shard &shard, LookupKey &key, - function_ref ctorFn) { + BaseStorage *getOrCreateUnsafe(Shard &shard, LookupKey &key, + function_ref ctorFn) { auto existing = shard.instances.insert_as({key.hashValue}, key); BaseStorage *&storage = existing.first->storage; if (existing.second) - storage = ctorFn(shard.allocator); + storage = ctorFn(); return storage; } @@ -135,10 +131,9 @@ } } /// Get or create an instance of a parametric type. - BaseStorage * - getOrCreate(bool threadingIsEnabled, unsigned hashValue, - function_ref isEqual, - function_ref ctorFn) { + BaseStorage *getOrCreate(bool threadingIsEnabled, unsigned hashValue, + function_ref isEqual, + function_ref ctorFn) { Shard &shard = getShard(hashValue); ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; if (!threadingIsEnabled) @@ -163,17 +158,20 @@ llvm::sys::SmartScopedWriter typeLock(shard.mutex); return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn); } + /// Run a mutation function on the provided storage object in a thread-safe /// way. - LogicalResult - mutate(bool threadingIsEnabled, BaseStorage *storage, - function_ref mutationFn) { - Shard &shard = getShardFor(storage); + LogicalResult mutate(bool threadingIsEnabled, BaseStorage *storage, + function_ref mutationFn) { if (!threadingIsEnabled) - return mutationFn(shard.allocator); + return mutationFn(); + // Get a shard to use for mutating this storage instance. It doesn't need to + // be the same shard as the original allocation, but does need to be + // deterministic. + Shard &shard = getShard(llvm::hash_value(storage)); llvm::sys::SmartScopedWriter lock(shard.mutex); - return mutationFn(shard.allocator); + return mutationFn(); } private: @@ -197,18 +195,6 @@ return *shard; } - /// Return the shard that allocated the provided storage object. - Shard &getShardFor(BaseStorage *storage) { - for (size_t i = 0; i != numShards; ++i) { - if (Shard *shard = shards[i].load(std::memory_order_acquire)) { - llvm::sys::SmartScopedReader lock(shard->mutex); - if (shard->allocator.allocated(storage)) - return *shard; - } - } - llvm_unreachable("expected storage object to have a valid shard"); - } - /// A thread local cache for storage objects. This helps to reduce the lock /// contention when an object already existing in the cache. ThreadLocalCache localCache; @@ -281,8 +267,9 @@ assert(parametricUniquers.count(id) && "creating unregistered storage instance"); ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; - return storageUniquer.getOrCreate(threadingIsEnabled, hashValue, isEqual, - ctorFn); + return storageUniquer.getOrCreate( + threadingIsEnabled, hashValue, isEqual, + [&] { return ctorFn(getThreadSafeAllocator()); }); } /// Run a mutation function on the provided storage object in a thread-safe @@ -293,7 +280,34 @@ assert(parametricUniquers.count(id) && "mutating unregistered storage instance"); ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; - return storageUniquer.mutate(threadingIsEnabled, storage, mutationFn); + return storageUniquer.mutate(threadingIsEnabled, storage, [&] { + return mutationFn(getThreadSafeAllocator()); + }); + } + + /// Return an allocator that can be used to safely allocate instances on the + /// current thread. + StorageAllocator &getThreadSafeAllocator() { +#if LLVM_ENABLE_THREADS != 0 + if (!threadingIsEnabled) + return allocator; + + // If the allocator has not been initialized, create a new one. + StorageAllocator *&threadAllocator = threadSafeAllocator.get(); + if (!threadAllocator) { + threadAllocator = new StorageAllocator(); + + // Record this allocator, given that we don't want it to be destroyed when + // the thread dies. + llvm::sys::SmartScopedLock lock(threadAllocatorMutex); + threadAllocators.push_back( + std::unique_ptr(threadAllocator)); + } + + return *threadAllocator; +#else + return allocator; +#endif } //===--------------------------------------------------------------------===// @@ -314,6 +328,22 @@ // Instance Storage //===--------------------------------------------------------------------===// +#if LLVM_ENABLE_THREADS != 0 + /// A thread local set of allocators used for uniquing parametric instances, + /// or other data allocated in thread volatile situations. + ThreadLocalCache threadSafeAllocator; + + /// All of the allocators that have been created for thread based allocation. + std::vector> threadAllocators; + + /// A mutex used for safely adding a new thread allocator. + llvm::sys::SmartMutex threadAllocatorMutex; +#endif + + /// Main allocator used for uniquing singleton instances, and other state when + /// thread safety is guaranteed. + StorageAllocator allocator; + /// Map of type ids to the storage uniquer to use for registered objects. DenseMap> parametricUniquers; @@ -322,9 +352,6 @@ /// singleton. DenseMap singletonInstances; - /// Allocator used for uniquing singleton instances. - StorageAllocator singletonAllocator; - /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; }; @@ -378,7 +405,7 @@ TypeID id, function_ref ctorFn) { assert(!impl->singletonInstances.count(id) && "storage class already registered"); - impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator)); + impl->singletonInstances.try_emplace(id, ctorFn(impl->allocator)); } /// Implementation for mutating an instance of a derived storage.