Index: mlir/include/mlir/Transforms/DialectConversion.h
===================================================================
--- mlir/include/mlir/Transforms/DialectConversion.h
+++ mlir/include/mlir/Transforms/DialectConversion.h
@@ -365,6 +365,12 @@
                                          MLIRContext *ctx,
                                          TypeConverter &converter);
 
+/// Add a pattern to the given pattern list to convert the result types of a
+/// CallOp with the given type converter.
+void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Transforms/DialectConversion.cpp
===================================================================
--- mlir/lib/Transforms/DialectConversion.cpp
+++ mlir/lib/Transforms/DialectConversion.cpp
@@ -1754,6 +1754,49 @@
   patterns.insert<FuncOpSignatureConversion>(ctx, converter);
 }
 
+namespace {
+
+/// Conversion pattern that replaces a CallOp with an equivalent CallOp whose
+/// result types have been converted with a TypeConverter and whose operands
+/// are the values remapped by the dialect conversion framework.
+struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
+  CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : OpConversionPattern(ctx), converter(converter) {}
+
+  /// Convert the result types of `callOp` and replace it with a new call to
+  /// the same callee. Fails to match if any result type cannot be converted.
+  PatternMatchResult
+  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionType type = callOp.getCalleeType();
+
+    // Convert the original function results.
+    SmallVector<Type, 1> convertedResults;
+    if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+      return matchFailure();
+
+    // Substitute with the new result types from the corresponding FuncOp
+    // signature conversion. Pass the remapped `operands` rather than
+    // `callOp.getArgOperands()`: the latter still refer to the original,
+    // unconverted values.
+    auto newCallOp = rewriter.create<CallOp>(
+        callOp.getLoc(), callOp.callee(), convertedResults, operands);
+    rewriter.replaceOp(callOp, newCallOp.getResults());
+    return matchSuccess();
+  }
+
+  /// The type converter to use when rewriting the signature.
+  TypeConverter &converter;
+};
+
+} // end anonymous namespace
+
+void mlir::populateCallOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  patterns.insert<CallOpSignatureConversion>(ctx, converter);
+}
+
 /// This function converts the type signature of the given block, by invoking
 /// 'convertSignatureArg' for each argument. This function should return a valid
 /// conversion for the signature on success, None otherwise.