diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -421,12 +421,6 @@
   using ConversionPattern::matchAndRewrite;
 };
 
-/// Add a pattern to the given pattern list to convert the signature of a FuncOp
-/// with the given type converter.
-void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
-                                         MLIRContext *ctx,
-                                         TypeConverter &converter);
-
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
@@ -796,6 +790,61 @@
   MLIRContext &ctx;
 };
 
+//===----------------------------------------------------------------------===//
+// Function Signature Conversion
+//===----------------------------------------------------------------------===//
+
+namespace impl {
+/// Attempts to convert a FunctionLike op's input and output types using the
+/// given type converter. If that succeeds, converts the region's types using
+/// the given pattern rewriter, type converter, and the signature conversion
+/// result.
+FailureOr<FunctionType>
+convertFunctionType(FunctionType type, Region &region,
+                    TypeConverter &typeConverter,
+                    ConversionPatternRewriter &rewriter);
+} // namespace impl
+
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FunctionLike op.
+template <typename OpTy>
+struct FunctionLikeSignatureConversion : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpConversionPattern<OpTy>::getTypeConverter;
+
+  /// Hook to implement combined matching and rewriting for FunctionLike ops.
+  LogicalResult
+  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Attempt to convert the original function types.
+    FailureOr<FunctionType> newType = ::mlir::impl::convertFunctionType(
+        op.getType(), op.getBody(), *getTypeConverter(), rewriter);
+
+    // Bail out on conversion failure; otherwise update the signature in-place.
+    if (failed(newType))
+      return failure();
+
+    rewriter.updateRootInPlace(op, [&] { op.setType(newType.getValue()); });
+
+    return success();
+  }
+};
+
+/// Add a pattern to the given pattern list to convert the signature of a
+/// FunctionLike op with the given type converter.
+template <typename OpTy>
+void populateFunctionLikeTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  patterns.insert<FunctionLikeSignatureConversion<OpTy>>(converter, ctx);
+}
+
+/// Add a pattern to the given pattern list to convert the signature of a FuncOp
+/// with the given type converter.
+void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SetVector.h"
@@ -2514,44 +2515,6 @@
   return conversion;
 }
 
-/// Create a default conversion pattern that rewrites the type signature of a
-/// FuncOp.
-namespace {
-struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
-  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(converter, ctx) {}
-
-  /// Hook for derived classes to implement combined matching and rewriting.
-  LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = funcOp.getType();
-
-    // Convert the original function types.
-    TypeConverter::SignatureConversion result(type.getNumInputs());
-    SmallVector<Type, 1> newResults;
-    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
-        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
-                                           &result)))
-      return failure();
-
-    // Update the function signature in-place.
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(FunctionType::get(funcOp.getContext(),
-                                       result.getConvertedTypes(), newResults));
-    });
-    return success();
-  }
-};
-} // end anonymous namespace
-
-void mlir::populateFuncOpTypeConversionPattern(
-    OwningRewritePatternList &patterns, MLIRContext *ctx,
-    TypeConverter &converter) {
-  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
@@ -2669,6 +2632,31 @@
   return llvm::None;
 }
 
+//===----------------------------------------------------------------------===//
+// Function Signature Conversion
+//===----------------------------------------------------------------------===//
+
+FailureOr<FunctionType>
+mlir::impl::convertFunctionType(FunctionType type, Region &region,
+                                TypeConverter &typeConverter,
+                                ConversionPatternRewriter &rewriter) {
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&region, typeConverter, &result)))
+    return failure();
+
+  return FunctionType::get(rewriter.getContext(), result.getConvertedTypes(),
+                           newResults);
+}
+
+void mlir::populateFuncOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//