diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -307,6 +307,14 @@ /// existing value are expected to be removed during conversion. If /// `llvm::None` is returned, the converter is allowed to try another /// conversion function to perform the conversion. + /// * Optional(T, SmallVectorImpl &, ArrayRef) + /// - This form represents a 1-N type conversion supporting recursive + /// types. The first two arguments and the return value are the same as + /// for the regular 1-N form. The third argument is the + /// "call stack" of the recursive conversion: it contains the list of + /// types currently being converted, with the current type being the + /// last one. If it is present more than once in the list, the + /// conversion concerns a recursive type. /// Note: When attempting to convert a type, e.g. via 'convertType', the /// mostly recently added conversions will be invoked first. template (T, SmallVectorImpl &, ArrayRef) + /// - This form represents a 1-N type conversion supporting recursive + /// types. The first two arguments and the return value are the same as + /// for the regular 1-N form. The third argument is the + /// "call stack" of the recursive conversion: it contains the list of + /// types currently being converted, with the current type being the + /// last one. If it is present more than once in the list, the + /// conversion concerns a recursive type. /// Note: When attempting to convert a type, e.g. via 'convertType', the /// mostly recently added conversions will be invoked first. template (Type, SmallVectorImpl &)>; + using ConversionCallbackFn = std::function( + Type, SmallVectorImpl &, ArrayRef)>; /// The signature of the callback used to materialize a conversion. 
using MaterializationCallbackFn = @@ -240,28 +248,44 @@ template std::enable_if_t::value, ConversionCallbackFn> wrapCallback(FnT &&callback) { - return wrapCallback([callback = std::forward(callback)]( - T type, SmallVectorImpl &results) { - if (Optional resultOpt = callback(type)) { - bool wasSuccess = static_cast(resultOpt.getValue()); - if (wasSuccess) - results.push_back(resultOpt.getValue()); - return Optional(success(wasSuccess)); - } - return Optional(); - }); - } - /// With callback of form: `Optional(T, SmallVectorImpl<> &)` + return wrapCallback( + [callback = std::forward(callback)]( + T type, SmallVectorImpl &results, ArrayRef) { + if (Optional resultOpt = callback(type)) { + bool wasSuccess = static_cast(resultOpt.getValue()); + if (wasSuccess) + results.push_back(resultOpt.getValue()); + return Optional(success(wasSuccess)); + } + return Optional(); + }); + } + /// With callback of form: `Optional(T, SmallVectorImpl + /// &)` template - std::enable_if_t::value, ConversionCallbackFn> + std::enable_if_t &>::value, + ConversionCallbackFn> + wrapCallback(FnT &&callback) { + return wrapCallback( + [callback = std::forward(callback)]( + T type, SmallVectorImpl &results, ArrayRef) { + return callback(type, results); + }); + } + /// With callback of form: `Optional(T, SmallVectorImpl + /// &, ArrayRef)`. + template + std::enable_if_t &, + ArrayRef>::value, + ConversionCallbackFn> wrapCallback(FnT &&callback) { return [callback = std::forward(callback)]( - Type type, - SmallVectorImpl &results) -> Optional { + Type type, SmallVectorImpl &results, + ArrayRef callStack) -> Optional { T derivedType = type.dyn_cast(); if (!derivedType) return llvm::None; - return callback(derivedType, results); + return callback(derivedType, results, callStack); }; } @@ -300,6 +324,10 @@ DenseMap cachedDirectConversions; /// This cache stores the successful 1->N conversions, where N != 1. 
DenseMap> cachedMultiConversions; + + /// Stores the types that are being converted in the case when convertType + /// is being called recursively to convert nested types. + SmallVector conversionCallStack; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/FunctionSupport.h" #include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/Utils.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" @@ -2931,8 +2932,12 @@ // Walk the added converters in reverse order to apply the most recently // registered first. size_t currentCount = results.size(); + conversionCallStack.push_back(t); + auto popConversionCallStack = + llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); }); for (ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (Optional result = converter(t, results)) { + if (Optional result = + converter(t, results, conversionCallStack)) { if (!succeeded(*result)) { cachedDirectConversions.try_emplace(t, nullptr); return failure(); diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -112,3 +112,12 @@ }) : () -> () return } + +// ----- + +// CHECK-LABEL: @recursive_type_conversion +func @recursive_type_conversion() { + // CHECK: !test.test_rec + "test.type_producer"() : () -> !test.test_rec> + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ 
b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestTypes.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" @@ -924,10 +925,16 @@ matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Type resultType = op.getType(); + Type convertedType = getTypeConverter() + ? getTypeConverter()->convertType(resultType) + : resultType; if (resultType.isa()) resultType = rewriter.getF64Type(); else if (resultType.isInteger(16)) resultType = rewriter.getIntegerType(64); + else if (resultType.isa() && + convertedType != resultType) + resultType = convertedType; else return failure(); @@ -1035,6 +1042,35 @@ // Drop all integer types. return success(); }); + converter.addConversion( + // Convert a recursive self-referring type into a non-self-referring + // type named "outer_converted_type" that contains a SimpleAType. + [&](test::TestRecursiveType type, SmallVectorImpl &results, + ArrayRef callStack) -> Optional { + // If the type is already converted, return it to indicate that it is + // legal. + if (type.getName() == "outer_converted_type") { + results.push_back(type); + return success(); + } + + // If the type is on the call stack more than once (it is there at + // least once because of the _current_ call), we've hit the recursive + // case. Just return SimpleAType here to create a non-recursive type + // as a result. + if (llvm::count(callStack, type) > 1) { + results.push_back(test::SimpleAType::get(type.getContext())); + return success(); + } + + // Convert the body recursively. 
+ auto result = test::TestRecursiveType::get(type.getContext(), + "outer_converted_type"); + if (failed(result.setBody(converter.convertType(type.getBody())))) + return failure(); + results.push_back(result); + return success(); + }); /// Add the legal set of type materializations. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, @@ -1059,7 +1095,10 @@ // Initialize the conversion target. mlir::ConversionTarget target(getContext()); target.addDynamicallyLegalOp([](TestTypeProducerOp op) { - return op.getType().isF64() || op.getType().isInteger(64); + auto recursiveType = op.getType().dyn_cast(); + return op.getType().isF64() || op.getType().isInteger(64) || + (recursiveType && + recursiveType.getName() == "outer_converted_type"); }); target.addDynamicallyLegalOp([&](FuncOp op) { return converter.isSignatureLegal(op.getType()) &&