diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -170,6 +170,9 @@ LLVM::LLVMDialect *llvmDialect; private: + // Recursive structure detection + SmallVector conversionCallStack; + /// Convert a function type. The arguments and results are converted one by /// one. Additionally, if the function returns more than one value, pack the /// results into an LLVM IR structure type so that the converted function type diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -307,7 +307,7 @@ /// types is empty, the type is removed and any usages of the existing value /// are expected to be removed during conversion. using ConversionCallbackFn = std::function( - Type, SmallVectorImpl &, ArrayRef)>; + Type, SmallVectorImpl &)>; /// The signature of the callback used to materialize a conversion. using MaterializationCallbackFn = std::function( @@ -330,44 +330,30 @@ template std::enable_if_t, ConversionCallbackFn> wrapCallback(FnT &&callback) const { - return wrapCallback( - [callback = std::forward(callback)]( - T type, SmallVectorImpl &results, ArrayRef) { - if (std::optional resultOpt = callback(type)) { - bool wasSuccess = static_cast(*resultOpt); - if (wasSuccess) - results.push_back(*resultOpt); - return std::optional(success(wasSuccess)); - } - return std::optional(); - }); + return wrapCallback([callback = std::forward(callback)]( + T type, SmallVectorImpl &results) { + if (std::optional resultOpt = callback(type)) { + bool wasSuccess = static_cast(*resultOpt); + if (wasSuccess) + results.push_back(*resultOpt); + return std::optional(success(wasSuccess)); + } + return std::optional(); + }); } /// With callback of form: `std::optional( - /// T, SmallVectorImpl &)`. + /// T, SmallVectorImpl &, ArrayRef)`. template std::enable_if_t &>, ConversionCallbackFn> - wrapCallback(FnT &&callback) const { - return wrapCallback( - [callback = std::forward(callback)]( - T type, SmallVectorImpl &results, ArrayRef) { - return callback(type, results); - }); - } - /// With callback of form: `std::optional( - /// T, SmallVectorImpl &, ArrayRef)`. - template - std::enable_if_t< - std::is_invocable_v &, ArrayRef>, - ConversionCallbackFn> wrapCallback(FnT &&callback) const { return [callback = std::forward(callback)]( - Type type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + Type type, + SmallVectorImpl &results) -> std::optional { T derivedType = dyn_cast(type); if (!derivedType) return std::nullopt; - return callback(derivedType, results, callStack); + return callback(derivedType, results); }; } @@ -435,10 +421,6 @@ mutable DenseMap cachedDirectConversions; /// This cache stores the successful 1->N conversions, where N != 1. mutable DenseMap> cachedMultiConversions; - - /// Stores the types that are being converted in the case when convertType - /// is being called recursively to convert nested types. - mutable SmallVector conversionCallStack; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -11,6 +11,7 @@ #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "llvm/ADT/ScopeExit.h" #include using namespace mlir; @@ -56,13 +57,16 @@ return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace()); return std::nullopt; }); - addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl &results) + -> std::optional { // Fastpath for types that won't be converted by this callback anyway. if (LLVM::isCompatibleType(type)) { results.push_back(type); return success(); } + conversionCallStack.push_back(type); + auto popConversionCallStack = + llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); }); if (type.isIdentified()) { auto convertedType = LLVM::LLVMStructType::getIdentified( @@ -75,7 +79,7 @@ type.getContext(), ("_Converted_" + std::to_string(counter) + type.getName()).str()); } - if (llvm::count(callStack, type) > 1) { + if (llvm::count(conversionCallStack, type) > 1) { results.push_back(convertedType); return success(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2926,12 +2926,9 @@ // Walk the added converters in reverse order to apply the most recently // registered first. size_t currentCount = results.size(); - conversionCallStack.push_back(t); - auto popConversionCallStack = - llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); }); + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional result = - converter(t, results, conversionCallStack)) { + if (std::optional result = converter(t, results)) { if (!succeeded(*result)) { cachedDirectConversions.try_emplace(t, nullptr); return failure(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -17,6 +17,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ScopeExit.h" using namespace mlir; using namespace test; @@ -1374,6 +1375,7 @@ void runOnOperation() override { // Initialize the type converter. + SmallVector conversionCallStack; TypeConverter converter; /// Add the legal set of type conversions. @@ -1394,8 +1396,8 @@ converter.addConversion( // Convert a recursive self-referring type into a non-self-referring // type named "outer_converted_type" that contains a SimpleAType. - [&](test::TestRecursiveType type, SmallVectorImpl &results, - ArrayRef callStack) -> std::optional { + [&](test::TestRecursiveType type, + SmallVectorImpl &results) -> std::optional { // If the type is already converted, return it to indicate that it is // legal. if (type.getName() == "outer_converted_type") { @@ -1403,11 +1405,16 @@ return success(); } + conversionCallStack.push_back(type); + auto popConversionCallStack = llvm::make_scope_exit( + [&conversionCallStack]() { conversionCallStack.pop_back(); }); + // If the type is on the call stack more than once (it is there at // least once because of the _current_ call, which is always the last // element on the stack), we've hit the recursive case. Just return // SimpleAType here to create a non-recursive type as a result. - if (llvm::is_contained(callStack.drop_back(), type)) { + if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), + type)) { results.push_back(test::SimpleAType::get(type.getContext())); return success(); }