diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -507,6 +507,12 @@
                                                    patterns, converter);
 }
 
+/// Add a pattern to the given pattern list to convert the signature of any op
+/// implementing FunctionOpInterface with the given type converter. Only ops
+/// whose type is a FunctionType are rewritten; the pattern fails on others.
+void populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3056,6 +3056,29 @@
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//
 
+/// Converts `funcOp`'s FunctionType signature in place; fails for other types.
+static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+                                        TypeConverter &typeConverter,
+                                        ConversionPatternRewriter &rewriter) {
+  auto type = funcOp.getFunctionType().dyn_cast<FunctionType>();
+  if (!type)
+    return failure();
+  // Convert the original function types.
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                         typeConverter, &result)))
+    return failure();
+
+  // Update the function signature in-place.
+  auto newType = FunctionType::get(rewriter.getContext(),
+                                   result.getConvertedTypes(), newResults);
+  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+  return success();
+}
+
 /// Create a default conversion pattern that rewrites the type signature of a
 /// FunctionOpInterface op. This only supports ops which use FunctionType to
 /// represent their type.
@@ -3067,27 +3090,27 @@
       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
                   ConversionPatternRewriter &rewriter) const override {
     FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
-    FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
-
-    // Convert the original function types.
-    TypeConverter::SignatureConversion result(type.getNumInputs());
-    SmallVector<Type, 1> newResults;
-    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
-        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                           *typeConverter, &result)))
-      return failure();
-
-    // Update the function signature in-place.
-    auto newType = FunctionType::get(rewriter.getContext(),
-                                     result.getConvertedTypes(), newResults);
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+  }
+};
 
-    rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); });
+/// Signature conversion pattern that matches every op implementing
+/// FunctionOpInterface rather than a single, named function-like op.
+struct AnyFunctionOpInterfaceSignatureConversion
+    : public OpInterfaceConversionPattern<FunctionOpInterface> {
+  using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
 
-    return success();
+  LogicalResult
+  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Interface implementers may describe their signature with a type other
+    // than FunctionType; only FunctionType signatures can be converted here.
+    if (!funcOp.getFunctionType().isa<FunctionType>())
+      return failure();
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
   }
 };
 } // namespace
@@ -3099,6 +3122,12 @@
       functionLikeOpName, patterns.getContext(), converter);
 }
 
+void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
+      converter, patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -786,8 +786,8 @@
         TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
         TestCreateUnregisteredOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
 
     // Define the conversion target used for the test.
@@ -1313,8 +1313,8 @@
                  TestTestSignatureConversionNoConverter>(converter,
                                                          &getContext());
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))