diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h --- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h +++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h @@ -57,8 +57,7 @@ // fir.type --> llvm<"%name = { ty... }"> std::optional convertRecordType(fir::RecordType derived, - llvm::SmallVectorImpl &results, - llvm::ArrayRef callStack) const; + llvm::SmallVectorImpl &results); // Is an extended descriptor needed given the element type of a fir.box type ? // Extended descriptors are required for derived types. diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp --- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp @@ -21,6 +21,7 @@ #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" namespace fir { @@ -81,11 +82,10 @@ }); addConversion( [&](fir::PointerType pointer) { return convertPointerLike(pointer); }); - addConversion([&](fir::RecordType derived, - llvm::SmallVectorImpl &results, - llvm::ArrayRef callStack) { - return convertRecordType(derived, results, callStack); - }); + addConversion( + [&](fir::RecordType derived, llvm::SmallVectorImpl &results) { + return convertRecordType(derived, results); + }); addConversion( [&](fir::RealType real) { return convertRealType(real.getFKind()); }); addConversion( @@ -167,14 +167,19 @@ // fir.type --> llvm<"%name = { ty... 
}"> std::optional LLVMTypeConverter::convertRecordType( - fir::RecordType derived, llvm::SmallVectorImpl &results, - llvm::ArrayRef callStack) const { + fir::RecordType derived, llvm::SmallVectorImpl &results) { auto name = derived.getName(); auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name); - if (llvm::count(callStack, derived) > 1) { + + auto &callStack = getCurrentThreadRecursiveStack(); + if (llvm::count(callStack, derived)) { results.push_back(st); return mlir::success(); } + callStack.push_back(derived); + auto popConversionCallStack = + llvm::make_scope_exit([&callStack]() { callStack.pop_back(); }); + llvm::SmallVector members; for (auto mem : derived.getTypeList()) { // Prevent fir.box from degenerating to a pointer to a descriptor in the diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -162,6 +162,12 @@ /// Pointer to the LLVM dialect. LLVM::LLVMDialect *llvmDialect; + // Recursive structure detection. + // We store one entry per thread here, and rely on locking. + DenseMap>> conversionCallStack; + llvm::sys::SmartRWMutex callStackMutex; + SmallVector &getCurrentThreadRecursiveStack(); + private: /// Convert a function type. The arguments and results are converted one by /// one. Additionally, if the function returns more than one value, pack the diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -307,7 +307,7 @@ /// types is empty, the type is removed and any usages of the existing value /// are expected to be removed during conversion. 
using ConversionCallbackFn = std::function( - Type, SmallVectorImpl &, ArrayRef)>; + Type, SmallVectorImpl &)>; /// The signature of the callback used to materialize a conversion. using MaterializationCallbackFn = std::function( @@ -330,44 +330,30 @@ template std::enable_if_t, ConversionCallbackFn> wrapCallback(FnT &&callback) const { - return wrapCallback( - [callback = std::forward(callback)]( - T type, SmallVectorImpl &results, ArrayRef) { - if (std::optional resultOpt = callback(type)) { - bool wasSuccess = static_cast(*resultOpt); - if (wasSuccess) - results.push_back(*resultOpt); - return std::optional(success(wasSuccess)); - } - return std::optional(); - }); + return wrapCallback([callback = std::forward(callback)]( + T type, SmallVectorImpl &results) { + if (std::optional resultOpt = callback(type)) { + bool wasSuccess = static_cast(*resultOpt); + if (wasSuccess) + results.push_back(*resultOpt); + return std::optional(success(wasSuccess)); + } + return std::optional(); + }); } /// With callback of form: `std::optional( - /// T, SmallVectorImpl &)`. + /// T, SmallVectorImpl &, ArrayRef)`. template std::enable_if_t &>, ConversionCallbackFn> - wrapCallback(FnT &&callback) const { - return wrapCallback( - [callback = std::forward(callback)]( - T type, SmallVectorImpl &results, ArrayRef) { - return callback(type, results); - }); - } - /// With callback of form: `std::optional( - /// T, SmallVectorImpl &, ArrayRef)`. 
- template - std::enable_if_t< - std::is_invocable_v &, ArrayRef>, - ConversionCallbackFn> wrapCallback(FnT &&callback) const { return [callback = std::forward(callback)]( - Type type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + Type type, + SmallVectorImpl &results) -> std::optional { T derivedType = dyn_cast(type); if (!derivedType) return std::nullopt; - return callback(derivedType, results, callStack); + return callback(derivedType, results); }; } @@ -435,10 +421,6 @@ mutable DenseMap cachedDirectConversions; /// This cache stores the successful 1->N conversions, where N != 1. mutable DenseMap> cachedMultiConversions; - - /// Stores the types that are being converted in the case when convertType - /// is being called recursively to convert nested types. - mutable SmallVector conversionCallStack; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -11,10 +11,34 @@ #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/Threading.h" +#include +#include #include using namespace mlir; +SmallVector &LLVMTypeConverter::getCurrentThreadRecursiveStack() { + { + // Most of the time, the entry already exists in the map. 
+ std::shared_lock lock(callStackMutex, + std::defer_lock); + if (getContext().isMultithreadingEnabled()) + lock.lock(); + auto recursiveStack = conversionCallStack.find(llvm::get_threadid()); + if (recursiveStack != conversionCallStack.end()) + return *recursiveStack->second; + } + + // First time this thread gets here, we have to get exclusive access to + // insert into the map + std::unique_lock lock(callStackMutex); + auto recursiveStackInserted = conversionCallStack.insert(std::make_pair( + llvm::get_threadid(), std::make_unique>())); + return *recursiveStackInserted.first->second.get(); +} + /// Create an LLVMTypeConverter using default LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis) @@ -56,8 +80,9 @@ return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace()); return std::nullopt; }); - addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + + addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl &results) + -> std::optional { // Fastpath for types that won't be converted by this callback anyway.
if (LLVM::isCompatibleType(type)) { results.push_back(type); @@ -75,10 +100,15 @@ type.getContext(), ("_Converted_" + std::to_string(counter) + type.getName()).str()); } - if (llvm::count(callStack, type) > 1) { + + SmallVectorImpl &recursiveStack = getCurrentThreadRecursiveStack(); + if (llvm::count(recursiveStack, type)) { results.push_back(convertedType); return success(); } + recursiveStack.push_back(type); + auto popConversionCallStack = llvm::make_scope_exit( + [&recursiveStack]() { recursiveStack.pop_back(); }); SmallVector convertedElemTypes; convertedElemTypes.reserve(type.getBody().size()); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2935,12 +2935,9 @@ // Walk the added converters in reverse order to apply the most recently // registered first. size_t currentCount = results.size(); - conversionCallStack.push_back(t); - auto popConversionCallStack = - llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); }); + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional result = - converter(t, results, conversionCallStack)) { + if (std::optional result = converter(t, results)) { if (!succeeded(*result)) { cachedDirectConversions.try_emplace(t, nullptr); return failure(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -17,6 +17,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ScopeExit.h" using namespace mlir; using namespace test; @@ -1374,6 +1375,7 @@ void runOnOperation() override { // Initialize the type converter. 
+ SmallVector conversionCallStack; TypeConverter converter; /// Add the legal set of type conversions. @@ -1394,8 +1396,8 @@ converter.addConversion( // Convert a recursive self-referring type into a non-self-referring // type named "outer_converted_type" that contains a SimpleAType. - [&](test::TestRecursiveType type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + [&](test::TestRecursiveType type, + SmallVectorImpl &results) -> std::optional { // If the type is already converted, return it to indicate that it is // legal. if (type.getName() == "outer_converted_type") { @@ -1403,11 +1405,16 @@ return success(); } + conversionCallStack.push_back(type); + auto popConversionCallStack = llvm::make_scope_exit( + [&conversionCallStack]() { conversionCallStack.pop_back(); }); + // If the type is on the call stack more than once (it is there at // least once because of the _current_ call, which is always the last // element on the stack), we've hit the recursive case. Just return // SimpleAType here to create a non-recursive type as a result. - if (llvm::is_contained(callStack.drop_back(), type)) { + if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), + type)) { results.push_back(test::SimpleAType::get(type.getContext())); return success(); }