diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -353,9 +353,7 @@ return OpBuilder::atBlockEnd(module.getBody()) .create(loc, functionName, functionType); }(); - return builder.create( - loc, const_cast(functionType).getReturnType(), - builder.getSymbolRefAttr(function), arguments); + return builder.create(loc, function, arguments); } // Returns whether all operands are of LLVM type. diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -248,7 +248,7 @@ } // Create call to `bindMemRef`. builder.create( - loc, TypeRange{getVoidType()}, + loc, TypeRange(), builder.getSymbolRefAttr( StringRef(symbolName.data(), symbolName.size())), ValueRange{vulkanRuntime, descriptorSet, descriptorBinding, @@ -396,32 +396,31 @@ // Create call to `setBinaryShader` runtime function with the given pointer to // SPIR-V binary and binary size. builder.create( - loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader), + loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader), ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); // Create LLVM global with entry point name. Value entryPointName = createEntryPointNameConstant( spirvAttributes.second.getValue(), loc, builder); // Create call to `setEntryPoint` runtime function with the given pointer to // entry point name. - builder.create(loc, TypeRange{getVoidType()}, + builder.create(loc, TypeRange(), builder.getSymbolRefAttr(kSetEntryPoint), ValueRange{vulkanRuntime, entryPointName}); // Create number of local workgroup for each dimension. builder.create( - loc, TypeRange{getVoidType()}, - builder.getSymbolRefAttr(kSetNumWorkGroups), + loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups), ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), cInterfaceVulkanLaunchCallOp.getOperand(1), cInterfaceVulkanLaunchCallOp.getOperand(2)}); // Create call to `runOnVulkan` runtime function. - builder.create(loc, TypeRange{getVoidType()}, + builder.create(loc, TypeRange(), builder.getSymbolRefAttr(kRunOnVulkan), ValueRange{vulkanRuntime}); // Create call to 'deinitVulkan' runtime function. - builder.create(loc, TypeRange{getVoidType()}, + builder.create(loc, TypeRange(), builder.getSymbolRefAttr(kDeinitVulkan), ValueRange{vulkanRuntime}); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -815,6 +815,19 @@ << ": " << op.getOperand(i + isIndirect).getType() << " != " << funcType.getParamType(i); + if (op.getNumResults() == 0 && + !funcType.getReturnType().isa()) + return op.emitOpError() << "expected function call to produce a value"; + + if (op.getNumResults() != 0 && + funcType.getReturnType().isa()) + return op.emitOpError() + << "calling function with void result must not produce values"; + + if (op.getNumResults() > 1) + return op.emitOpError() + << "expected LLVM function call to produce 0 or 1 result"; + if (op.getNumResults() && op.getResult(0).getType() != funcType.getReturnType()) return op.emitOpError() @@ -874,19 +887,18 @@ auto funcType = type.dyn_cast(); if (!funcType) return parser.emitError(trailingTypeLoc, "expected function type"); + if (funcType.getNumResults() > 1) + return parser.emitError(trailingTypeLoc, + "expected function with 0 or 1 result"); if (isDirect) { // Make sure types match. if (parser.resolveOperands(operands, funcType.getInputs(), parser.getNameLoc(), result.operands)) return failure(); - result.addTypes(funcType.getResults()); + if (funcType.getNumResults() != 0 && + !funcType.getResult(0).isa()) + result.addTypes(funcType.getResults()); } else { - // Construct the LLVM IR Dialect function type that the first operand - // should match. - if (funcType.getNumResults() > 1) - return parser.emitError(trailingTypeLoc, - "expected function with 0 or 1 result"); - Builder &builder = parser.getBuilder(); Type llvmResultType; if (funcType.getNumResults() == 0) { @@ -921,7 +933,8 @@ parser.getNameLoc(), result.operands)) return failure(); - result.addTypes(llvmResultType); + if (!llvmResultType.isa()) + result.addTypes(llvmResultType); } return success(); diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir --- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir +++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir @@ -6,14 +6,14 @@ // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]] // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant -// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) -> !llvm.void -// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> !llvm.void +// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) -> () +// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> () // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]] -// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.void -// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> !llvm.void -// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> !llvm.void -// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> !llvm.void +// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> () +// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> () +// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () +// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> () // CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>>) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1089,3 +1089,33 @@ %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)> llvm.return } + +// ----- + +llvm.func @caller() { + // expected-error @below {{expected function call to produce a value}} + llvm.call @callee() : () -> () + llvm.return +} + +llvm.func @callee() -> i32 + +// ----- + +llvm.func @caller() { + // expected-error @below {{calling function with void result must not produce values}} + %0 = llvm.call @callee() : () -> i32 + llvm.return +} + +llvm.func @callee() -> () + +// ----- + +llvm.func @caller() { + // expected-error @below {{expected function with 0 or 1 result}} + %0:2 = llvm.call @callee() : () -> (i32, f32) + llvm.return +} + +llvm.func @callee() -> !llvm.struct<(i32, f32)>