// Review preamble for the patch below (mlir/lib/Support/StorageUniquer.cpp).
//
// What the patch changes:
//  * Removes the per-shard `StorageAllocator allocator` member. getOrCreateUnsafe()
//    and getOrCreate() now take an extra `getAllocFn` callback. getOrCreateUnsafe()
//    invokes it only when `insert_as` actually inserted a new entry
//    (`existing.second`), so no allocator is fetched on a cache/set hit.
//  * mutate() now receives the allocator explicitly. It picks the shard to lock
//    via `getShard(llvm::hash_value(storage))`. This replaces getShardFor(), which
//    scanned every shard under a reader lock and asked
//    `shard->allocator.allocated(storage)`.
//  * The outer uniquer gains getThreadSafeAllocator():
//      - with LLVM_ENABLE_THREADS != 0, each thread lazily creates a
//        `new StorageAllocator()` in the `threadSafeAllocator` thread-local slot;
//      - ownership is pushed into the `threadAllocators` vector under
//        `threadAllocatorMutex`, so the allocator is not destroyed when the
//        thread exits;
//      - otherwise it returns the main `allocator`.
//    Both parametric getOrCreate and mutate route through it.
//  * `singletonAllocator` is renamed to `allocator`; singleton registration
//    now constructs with `ctorFn(impl->allocator)`.
//
// NOTE(review): this copy is not applicable as-is.
//  * Newlines were collapsed, so hunks no longer match their @@ line counts.
//  * Template argument lists appear to have been stripped during extraction.
//    Examples: `function_ref ctorFn`, `ThreadLocalCache threadSafeAllocator`,
//    `std::vector> threadAllocators`, `std::unique_ptr(allocator)`.
//  * The last one is ill-formed C++: std::unique_ptr has no deduction guide
//    from a raw pointer.
//  * Restore both from the upstream commit before applying.
//
// NOTE(review): getThreadSafeAllocator() only checks the compile-time
// LLVM_ENABLE_THREADS macro. Builds with threads compiled in use the
// thread-local allocator even when `threadingIsEnabled` is false at runtime.
// This looks intentional but should be confirmed.
//
// NOTE(review): mutate() locks a shard chosen by the storage *pointer* hash.
// getOrCreate() locks the shard chosen by the *key* hash. Mutations of the same
// storage are therefore serialized with each other, but not necessarily with
// lookups on that storage's key shard. The patch comment says only determinism
// is required; confirm that mutable state is never read through the lookup path.
//
// NOTE(review): in getThreadSafeAllocator(), the thread-local slot is assigned
// the raw `new` result before ownership moves into `threadAllocators`. If
// push_back could fail (only relevant with exceptions enabled), the slot would
// hold an unowned allocator. Creating the unique_ptr first and then publishing
// `.get()` avoids that window.
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -82,9 +82,6 @@ /// The set containing the allocated storage instances. StorageTypeSet instances; - /// Allocator to use when constructing derived instances. - StorageAllocator allocator; - #if LLVM_ENABLE_THREADS != 0 /// A mutex to keep uniquing thread-safe. llvm::sys::SmartRWMutex mutex; @@ -95,11 +92,12 @@ /// fashion. BaseStorage * getOrCreateUnsafe(Shard &shard, LookupKey &key, - function_ref ctorFn) { + function_ref ctorFn, + function_ref getAllocFn) { auto existing = shard.instances.insert_as({key.hashValue}, key); BaseStorage *&storage = existing.first->storage; if (existing.second) - storage = ctorFn(shard.allocator); + storage = ctorFn(getAllocFn()); return storage; } @@ -138,11 +136,12 @@ BaseStorage * getOrCreate(bool threadingIsEnabled, unsigned hashValue, function_ref isEqual, - function_ref ctorFn) { + function_ref ctorFn, + function_ref getAllocFn) { Shard &shard = getShard(hashValue); ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; if (!threadingIsEnabled) - return getOrCreateUnsafe(shard, lookupKey, ctorFn); + return getOrCreateUnsafe(shard, lookupKey, ctorFn, getAllocFn); // Check for a instance of this object in the local cache. auto localIt = localCache->insert_as({hashValue}, lookupKey); @@ -161,19 +160,24 @@ // Acquire a writer-lock so that we can safely create the new storage // instance. llvm::sys::SmartScopedWriter typeLock(shard.mutex); - return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn); + return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn, getAllocFn); } + /// Run a mutation function on the provided storage object in a thread-safe /// way.
LogicalResult - mutate(bool threadingIsEnabled, BaseStorage *storage, + mutate(bool threadingIsEnabled, StorageAllocator &allocator, + BaseStorage *storage, function_ref mutationFn) { - Shard &shard = getShardFor(storage); + // Get a shard to use for mutating this storage instance. It doesn't need to + // be the same shard as the original allocation, but does need to be + // deterministic. + Shard &shard = getShard(llvm::hash_value(storage)); if (!threadingIsEnabled) - return mutationFn(shard.allocator); + return mutationFn(allocator); llvm::sys::SmartScopedWriter lock(shard.mutex); - return mutationFn(shard.allocator); + return mutationFn(allocator); } private: @@ -197,18 +201,6 @@ return *shard; } - /// Return the shard that allocated the provided storage object. - Shard &getShardFor(BaseStorage *storage) { - for (size_t i = 0; i != numShards; ++i) { - if (Shard *shard = shards[i].load(std::memory_order_acquire)) { - llvm::sys::SmartScopedReader lock(shard->mutex); - if (shard->allocator.allocated(storage)) - return *shard; - } - } - llvm_unreachable("expected storage object to have a valid shard"); - } - /// A thread local cache for storage objects. This helps to reduce the lock /// contention when an object already existing in the cache.
ThreadLocalCache localCache; @@ -281,8 +273,9 @@ assert(parametricUniquers.count(id) && "creating unregistered storage instance"); ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; - return storageUniquer.getOrCreate(threadingIsEnabled, hashValue, isEqual, - ctorFn); + return storageUniquer.getOrCreate( + threadingIsEnabled, hashValue, isEqual, ctorFn, + [&]() -> StorageAllocator & { return getThreadSafeAllocator(); }); } /// Run a mutation function on the provided storage object in a thread-safe @@ -293,7 +286,30 @@ assert(parametricUniquers.count(id) && "mutating unregistered storage instance"); ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; - return storageUniquer.mutate(threadingIsEnabled, storage, mutationFn); + return storageUniquer.mutate(threadingIsEnabled, getThreadSafeAllocator(), + storage, mutationFn); + } + + /// Return an allocator that can be used to safely allocate instances on the + /// current thread. + StorageAllocator &getThreadSafeAllocator() { +#if LLVM_ENABLE_THREADS != 0 + StorageAllocator *&allocator = threadSafeAllocator.get(); + + // If the allocator has not been initialized, create a new one. + if (!allocator) { + allocator = new StorageAllocator(); + + // Record this allocator, given that we don't want it to be destroyed when + // the thread dies. + llvm::sys::SmartScopedLock lock(threadAllocatorMutex); + threadAllocators.push_back(std::unique_ptr(allocator)); + } + + return *allocator; +#else + return allocator; +#endif } //===--------------------------------------------------------------------===// @@ -314,6 +330,18 @@ // Instance Storage //===--------------------------------------------------------------------===// +#if LLVM_ENABLE_THREADS != 0 + /// A thread local set of allocators used for uniquing parametric instances, + /// or other data allocated in thread volatile situations.
+ ThreadLocalCache threadSafeAllocator; + + /// All of the allocators that have been created for thread based allocation. + std::vector> threadAllocators; + + /// A mutex used for safely adding a new thread allocator. + llvm::sys::SmartMutex threadAllocatorMutex; +#endif + /// Map of type ids to the storage uniquer to use for registered objects. DenseMap> parametricUniquers; @@ -322,8 +350,9 @@ /// singleton. DenseMap singletonInstances; - /// Allocator used for uniquing singleton instances. - StorageAllocator singletonAllocator; + /// Main allocator used for uniquing singleton instances, and other state when + /// thread safety is guaranteed. + StorageAllocator allocator; /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; @@ -378,7 +407,7 @@ TypeID id, function_ref ctorFn) { assert(!impl->singletonInstances.count(id) && "storage class already registered"); - impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator)); + impl->singletonInstances.try_emplace(id, ctorFn(impl->allocator)); } /// Implementation for mutating an instance of a derived storage.