diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -57,6 +57,36 @@
   }
 };
 
+/// Converts func.call to spv.FunctionCall.
+///
+/// Calls with zero or one result are rewritten, with the single result type
+/// mapped through the type converter. Calls with more results fail to match.
+class CallOpPattern final : public OpConversionPattern<func::CallOp> {
+public:
+  using OpConversionPattern<func::CallOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Functions with multiple results are not converted to spv.func, so a
+    // call to one cannot be converted either.
+    if (callOp.getNumResults() > 1)
+      return failure();
+    if (callOp.getNumResults() == 1) {
+      auto resultType =
+          getTypeConverter()->convertType(callOp.getResult(0).getType());
+      if (!resultType)
+        return failure();
+      rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
+          callOp, resultType, adaptor.getOperands(), callOp->getAttrs());
+    } else {
+      rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
+          callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -67,5 +97,5 @@
                                        RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
 
-  patterns.add<ReturnOpPattern>(typeConverter, context);
+  patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
 }
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -8,6 +8,12 @@
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
 } {
 
+// CHECK-LABEL: spv.func @return_none_val
+func @return_none_val() {
+  // CHECK: spv.Return
+  return
+}
+
 // CHECK-LABEL: spv.func @return_one_val
 // CHECK-SAME: (%[[ARG:.+]]: f32)
 func @return_one_val(%arg0: f32) -> f32 {
@@ -22,6 +28,24 @@
   return %arg0, %arg0: f32, f32
 }
 
+// CHECK-LABEL: spv.func @return_one_index
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+func @return_one_index(%arg0: index) -> index {
+  // CHECK: spv.ReturnValue %[[ARG]] : i32
+  return %arg0: index
+}
+
+// CHECK-LABEL: spv.func @call_functions
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+func @call_functions(%arg0: index) -> index {
+  // CHECK: spv.FunctionCall @return_none_val() : () -> ()
+  call @return_none_val(): () -> ()
+  // CHECK: %[[RET:.+]] = spv.FunctionCall @return_one_index(%[[ARG]]) : (i32) -> i32
+  %0 = call @return_one_index(%arg0): (index) -> index
+  // CHECK: spv.ReturnValue %[[RET]] : i32
+  return %0: index
+}
+
 }
 
 // -----