diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -38,6 +38,31 @@
 class TypeConverter {
 public:
   virtual ~TypeConverter() = default;
+  TypeConverter() = default;
+  /// Copy the registered conversions and materializations from `other`. The
+  /// conversion caches are intentionally not copied; the new converter starts
+  /// with empty caches.
+  TypeConverter(const TypeConverter &other)
+      : conversions(other.conversions),
+        argumentMaterializations(other.argumentMaterializations),
+        sourceMaterializations(other.sourceMaterializations),
+        targetMaterializations(other.targetMaterializations),
+        typeAttributeConversions(other.typeAttributeConversions) {}
+  /// Replace the registered conversions and materializations with those of
+  /// `other`. Cached results computed with the previous conversions are
+  /// discarded so that later lookups cannot return stale types.
+  TypeConverter &operator=(const TypeConverter &other) {
+    conversions = other.conversions;
+    argumentMaterializations = other.argumentMaterializations;
+    sourceMaterializations = other.sourceMaterializations;
+    targetMaterializations = other.targetMaterializations;
+    typeAttributeConversions = other.typeAttributeConversions;
+    // The caches were filled by the old conversions; keeping them would make
+    // convertType return results the new conversions never produced.
+    cachedDirectConversions.clear();
+    cachedMultiConversions.clear();
+    return *this;
+  }
 
   /// This class provides all of the information necessary to convert a type
   /// signature.
@@ -421,6 +446,8 @@
   mutable DenseMap cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
mutable DenseMap> cachedMultiConversions; + /// A mutex used for cache access + mutable llvm::sys::SmartRWMutex cacheMutex; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2920,24 +2920,34 @@ LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) const { - auto existingIt = cachedDirectConversions.find(t); - if (existingIt != cachedDirectConversions.end()) { - if (existingIt->second) - results.push_back(existingIt->second); - return success(existingIt->second != nullptr); - } - auto multiIt = cachedMultiConversions.find(t); - if (multiIt != cachedMultiConversions.end()) { - results.append(multiIt->second.begin(), multiIt->second.end()); - return success(); + { + std::shared_lock cacheReadLock(cacheMutex, + std::defer_lock); + if (t.getContext()->isMultithreadingEnabled()) + cacheReadLock.lock(); + auto existingIt = cachedDirectConversions.find(t); + if (existingIt != cachedDirectConversions.end()) { + if (existingIt->second) + results.push_back(existingIt->second); + return success(existingIt->second != nullptr); + } + auto multiIt = cachedMultiConversions.find(t); + if (multiIt != cachedMultiConversions.end()) { + results.append(multiIt->second.begin(), multiIt->second.end()); + return success(); + } } - // Walk the added converters in reverse order to apply the most recently // registered first. size_t currentCount = results.size(); + std::unique_lock cacheWriteLock(cacheMutex, + std::defer_lock); + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { if (std::optional result = converter(t, results)) { + if (t.getContext()->isMultithreadingEnabled()) + cacheWriteLock.lock(); if (!succeeded(*result)) { cachedDirectConversions.try_emplace(t, nullptr); return failure();