diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -248,6 +248,31 @@ } }; +/// Converts spv.FunctionCall to llvm.call, forwarding the converted operands and the op's attributes. +class FunctionCallPattern + : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { +public: + using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (callOp.getNumResults() == 0) { + rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands, + callOp.getAttrs()); + return success(); + } + + // Function returns a single result; fail the match if its type cannot be converted. + auto dstType = this->typeConverter.convertType(callOp.getType(0)); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands, + callOp.getAttrs()); + return success(); + } +}; + /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" template class FComparePattern : public SPIRVToLLVMConversion { @@ -551,6 +576,9 @@ IComparePattern, IComparePattern, + // Function Call op + FunctionCallPattern, + // Logical ops DirectConversionPattern, DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/func-to-llvm.mlir @@ -58,5 +58,38 @@ spv.ReturnValue %0 : vector<2xi64> } +//===----------------------------------------------------------------------===// +// spv.FunctionCall +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: llvm.func @function_calls +// CHECK-SAME: %[[ARG0:.*]]: !llvm.i32, %[[ARG1:.*]]: !llvm.i1, %[[ARG2:.*]]: !llvm.double, %[[ARG3:.*]]: !llvm<"<2 x i64>">, %[[ARG4:.*]]: !llvm<"<2 x float>"> +spv.func @function_calls(%arg0: i32, %arg1: i1, %arg2: f64, %arg3: vector<2xi64>, %arg4: vector<2xf32>) -> () "None" 
{ + // CHECK: llvm.call @void_1() : () -> () + spv.FunctionCall @void_1() : () -> () + // CHECK: llvm.call @void_2(%[[ARG3]]) : (!llvm<"<2 x i64>">) -> () + spv.FunctionCall @void_2(%arg3) : (vector<2xi64>) -> () + // CHECK: %{{.*}} = llvm.call @value_scalar(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!llvm.i32, !llvm.i1, !llvm.double) -> !llvm.i32 + %0 = spv.FunctionCall @value_scalar(%arg0, %arg1, %arg2) : (i32, i1, f64) -> i32 + // CHECK: %{{.*}} = llvm.call @value_vector(%[[ARG3]], %[[ARG4]]) : (!llvm<"<2 x i64>">, !llvm<"<2 x float>">) -> !llvm<"<2 x float>"> + %1 = spv.FunctionCall @value_vector(%arg3, %arg4) : (vector<2xi64>, vector<2xf32>) -> vector<2xf32> + spv.Return +} + +spv.func @void_1() -> () "None" { + spv.Return +} + +spv.func @void_2(%arg0: vector<2xi64>) -> () "None" { + spv.Return +} + +spv.func @value_scalar(%arg0: i32, %arg1: i1, %arg2: f64) -> i32 "None" { + spv.ReturnValue %arg0: i32 +} + +spv.func @value_vector(%arg0: vector<2xi64>, %arg1: vector<2xf32>) -> vector<2xf32> "None" { + spv.ReturnValue %arg1: vector<2xf32> +}