diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -57,6 +57,38 @@
   }
 };
 
+/// Converts func.call to spv.FunctionCall.
+///
+/// Only calls with zero or one result are lowered: functions with multiple
+/// results are not converted to spv.func, so such calls fail to match here.
+class CallOpPattern final : public OpConversionPattern<func::CallOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Multi-result functions are not converted to spv.func; bail out.
+    if (callOp.getNumResults() > 1)
+      return failure();
+    // Operands are taken from the adaptor; the call's attributes (including
+    // the callee symbol) are forwarded unchanged.
+    if (callOp.getNumResults() == 1) {
+      // convertType returns a null type if the result cannot be converted.
+      auto resultType =
+          getTypeConverter()->convertType(callOp.getResult(0).getType());
+      if (!resultType)
+        return failure();
+      rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
+          callOp, resultType, adaptor.getOperands(), callOp->getAttrs());
+    } else {
+      rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
+          callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -67,5 +99,5 @@
                                        RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
 
-  patterns.add<ReturnOpPattern>(typeConverter, context);
+  patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
 }
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -22,6 +22,37 @@
   return %arg0, %arg0: f32, f32
 }
 
+// CHECK-LABEL: spv.func @return_one_index
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+func @return_one_index(%arg0: index) -> index {
+  // CHECK: spv.ReturnValue %[[ARG]] : i32
+  return %arg0: index
+}
+
+// CHECK-LABEL: spv.func @call_return_one_index
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+func @call_return_one_index(%arg0: index) -> index {
+  // CHECK: %[[RES:.+]] = spv.FunctionCall @return_one_index(%[[ARG]]) : (i32) -> i32
+  %0 = call @return_one_index(%arg0): (index) -> index
+  // CHECK: spv.ReturnValue %[[RES]] : i32
+  return %0: index
+}
+
+// CHECK-LABEL: spv.func @return_none_index
+func @return_none_index(%arg0: index) {
+  // CHECK: spv.Return
+  return
+}
+
+// CHECK-LABEL: spv.func @call_return_none_index
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+func @call_return_none_index(%arg0: index) {
+  // CHECK: spv.FunctionCall @return_none_index(%[[ARG]]) : (i32) -> ()
+  call @return_none_index(%arg0) : (index) -> ()
+  // CHECK: spv.Return
+  return
+}
+
 }
 
 // -----