diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -924,10 +924,15 @@
 LogicalResult
 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
-  if (returnOp.getNumOperands()) {
+  // Only returns of zero or one value are converted.
+  if (returnOp.getNumOperands() > 1)
     return failure();
+
+  if (returnOp.getNumOperands() == 1) {
+    rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
+  } else {
+    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
   }
-  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -473,23 +473,33 @@
 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
   auto fnType = funcOp.getType();
-  // TODO: support converting functions with one result.
-  if (fnType.getNumResults())
+  // Only functions with zero or one result are converted.
+  if (fnType.getNumResults() > 1)
     return failure();
 
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  for (auto argType : enumerate(funcOp.getType().getInputs())) {
+  for (auto argType : enumerate(fnType.getInputs())) {
     auto convertedType = typeConverter.convertType(argType.value());
     if (!convertedType)
       return failure();
     signatureConverter.addInputs(argType.index(), convertedType);
   }
 
+  // Like the arguments above, the (single) result type must be convertible;
+  // otherwise fail instead of emitting an spv.func that silently drops it.
+  Type resultType;
+  if (fnType.getNumResults() == 1) {
+    resultType = typeConverter.convertType(fnType.getResult(0));
+    if (!resultType)
+      return failure();
+  }
+
   // Create the converted spv.func op.
   auto newFuncOp = rewriter.create<spirv::FuncOp>(
       funcOp.getLoc(), funcOp.getName(),
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
-                               llvm::None));
+                               resultType ? TypeRange(resultType)
+                                          : TypeRange()));
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp.getAttrs()) {
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -954,3 +954,29 @@
 }
 
 } // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.return
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>
+} {
+
+// CHECK-LABEL: spv.func @return_one_val
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+func @return_one_val(%arg0: f32) -> f32 {
+  // CHECK: spv.ReturnValue %[[ARG]] : f32
+  return %arg0: f32
+}
+
+// Check that multiple-return functions are not converted.
+// CHECK-LABEL: func @return_multi_val
+func @return_multi_val(%arg0: f32) -> (f32, f32) {
+  // CHECK: return
+  return %arg0, %arg0: f32, f32
+}
+
+}