diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -507,6 +507,12 @@
                                                    patterns, converter);
 }
 
+/// Add a pattern to the given pattern list to convert the signature of any
+/// operation implementing FunctionOpInterface with the given type converter.
+/// Ops whose function type is not a builtin FunctionType are not converted.
+void populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3090,6 +3090,44 @@
     return success();
   }
 };
+
+/// Converts the signature of any operation implementing FunctionOpInterface
+/// in place, using the pattern's type converter. Only ops whose function type
+/// is a builtin FunctionType are handled; others are reported as a match
+/// failure.
+struct AnyFunctionOpInterfaceSignatureConversion
+    : public OpInterfaceConversionPattern<FunctionOpInterface> {
+  using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // getFunctionType() returns a generic Type; not every FunctionOpInterface
+    // op uses the builtin FunctionType, so fail the match instead of
+    // asserting in `cast`.
+    FunctionType type = funcOp.getFunctionType().dyn_cast<FunctionType>();
+    if (!type)
+      return rewriter.notifyMatchFailure(
+          funcOp, "expected the function type to be a builtin FunctionType");
+
+    // Convert the original function types.
+    TypeConverter::SignatureConversion result(type.getNumInputs());
+    SmallVector<Type, 1> newResults;
+    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
+        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
+        failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                           *typeConverter, &result)))
+      return failure();
+
+    // Update the function signature in-place.
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     result.getConvertedTypes(), newResults);
+
+    rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
@@ -3099,6 +3137,12 @@
       functionLikeOpName, patterns.getContext(), converter);
 }
 
+void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
+      converter, patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -786,8 +786,8 @@
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
              TestCreateUnregisteredOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
 
     // Define the conversion target used for the test.
@@ -1313,8 +1313,8 @@
                  TestTestSignatureConversionNoConverter>(converter,
                                                          &getContext());
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))